blob: c5ac0f99390d94c8a29a86c530aab154d5fb7513 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Banner written at the top of every generated C++ meta-data file; the
# copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = "".join(
    (
        "// Copyright (c) {}, ARM Limited\n".format(datetime.today().year),
        "// SPDX-License-Identifier: Apache-2.0\n",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n",
    )
)

# Install a default root handler, then fetch the generator's shared logger.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
    # This currently matches the 8K level defined in the specification.
    # NOTE(review): presumably upper limits on the scale, kernel and stride
    # values the generator produces - confirm against the code reading them.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # NOTE(review): readers are outside this view - presumably the number of
    # dot-product statistical test sets and a minimum value used when
    # generating them; confirm.
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000
47
Eric Kunzee5e26762020-10-13 16:11:07 -070048 def __init__(self, args):
49 self.args = args
50 self.basePath = args.output_dir
51 self.random_seed = args.random_seed
52 self.ser = None
Eric Kunzee5e26762020-10-13 16:11:07 -070053 self.createDynamicOpLists()
54 self.initOpListDefaults()
55 self.quantGen = TosaQuantGen()
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010056 self.global_rng = None
Eric Kunzee5e26762020-10-13 16:11:07 -070057 # Force makeShape to do a specific starting shape
58 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010059 # JSON schema validation
60 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010061 # Data generator library is sometimes needed for compliance set up
62 # even if we are generating the data later (lazy_data_generation)
63 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070064
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010065 # Work out floating point range
66 def convertFPRange(rangeFP, maxFP):
67 # Converts program arguments of max/-max to FP max
68 vals = []
69 for v in rangeFP:
70 if v == "max":
71 v = maxFP
72 elif v == "-max":
73 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000074 elif v < 0:
75 # Trim to minimum data type value
76 v = max(v, -maxFP)
77 elif v > 0:
78 # Trim to maximum data type value
79 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010080 vals.append(v)
81 return tuple(sorted(vals))
82
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010083 self.random_dtype_range = {
84 DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
85 }
Won Jeon2c34b462024-02-06 18:37:00 +000086 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010087 self.random_dtype_range[dtype] = convertFPRange(
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010088 args.tensor_fp_value_range,
89 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
90 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010091 self.resetGlobalRNG()
92
93 def resetGlobalRNG(self):
94 self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010095
Eric Kunzee5e26762020-10-13 16:11:07 -070096 def createSerializer(self, opName, testPath):
97 self.testPath = os.path.join(opName, testPath)
98
99 fullPath = os.path.join(self.basePath, self.testPath)
100 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100101 # Embed const data in the flatbuffer
102 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +0100103 if self.args.lazy_data_gen:
104 # Lazy data generation - so make constants files
105 constMode = ts.ConstMode.INPUTS
106 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100107 constMode = ts.ConstMode.EMBED_DUMP
108 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700109
110 def getSerializer(self):
111 return self.ser
112
evacha01ad8e1e22024-03-19 12:42:17 +0000113 def serialize(self, testName, metaData=None, tags=None):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100114 path = Path(self.basePath) / self.testPath
115
116 # Write out TOSA flatbuffer binary
117 path_fb = path / f"{testName}.tosa"
118 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 fd.write(self.ser.serialize())
120
Jeremy Johnson1271c442023-09-05 11:39:26 +0100121 # Get JSON descriptor from serializer
122 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
123
124 if metaData:
125 # Add extra meta data to desc.json
126 desc["meta"] = metaData
127
evacha01ad8e1e22024-03-19 12:42:17 +0000128 if tags:
129 desc["tag"] = tags
130
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 # Validate desc.json before we output it
132 self.descSchemaValidator.validate_config(desc)
133
134 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100135 if "data_gen" in metaData:
136 if self.args.lazy_data_gen:
137 # Output datagen meta data as CPP data
138 path_md = path / f"{testName}_meta_data_gen.cpp"
139 with path_md.open("w") as fd:
140 fd.write(TOSA_AUTOGENERATED_HEADER)
141 fd.write("// Test meta data for data generation setup\n\n")
142 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
143 json.dump(metaData["data_gen"], fd)
144 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 if "compliance" in metaData:
146 # Output datagen meta data as CPP data
147 path_md = path / f"{testName}_meta_compliance.cpp"
148 with path_md.open("w") as fd:
149 fd.write(TOSA_AUTOGENERATED_HEADER)
150 fd.write("// Test meta data for compliance validation\n\n")
151 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
152 json.dump(metaData["compliance"], fd)
153 fd.write(')";\n\n')
154
155 # Write desc.json
156 path_desc = path / "desc.json"
157 with path_desc.open("w") as fd:
158 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700159
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100160 def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 placeholders = []
162
Kevin Cheng989cb052021-04-28 16:29:44 -0700163 assert len(shape_list) == len(dtype_list)
164
Jeremy Johnson1271c442023-09-05 11:39:26 +0100165 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700166 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100167 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100168 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700169 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700170
171 return placeholders
172
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100173 def buildConstTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 consts = []
175
Kevin Cheng989cb052021-04-28 16:29:44 -0700176 assert len(shape_list) == len(dtype_list)
177
Jeremy Johnson1271c442023-09-05 11:39:26 +0100178 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700179 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100180 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100181 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700182 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700183
184 return consts
185
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100186 def makeShape(self, rng, rank):
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 if self.targetted_shape:
188 return np.int32(self.targetted_shape)
Jeremy Johnson18a379d2024-03-28 15:53:21 +0000189 else:
190 return np.int32(
191 rng.integers(
192 low=self.args.tensor_shape_range[0],
193 high=self.args.tensor_shape_range[1],
194 size=rank,
195 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800196 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700197
198 def setTargetShape(self, shape):
199 self.targetted_shape = shape
200
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 def shapeStr(self, shape):
Jeremy Johnson18a379d2024-03-28 15:53:21 +0000202 assert shape is not None
203 if len(shape) > 0:
204 # Rank > 0
205 return "x".join([str(d) for d in shape])
206 else:
207 # Rank 0
208 return "0"
Eric Kunzee5e26762020-10-13 16:11:07 -0700209
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100210 def typeStr(self, dtype):
211 if isinstance(dtype, list) or isinstance(dtype, tuple):
212 assert len(dtype) >= 2
213 strs = [self.typeStr(t) for t in dtype]
214 # Limit types to the first 2 as the 3rd is the accumulator
215 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100217 if dtype in gtu.DTYPE_ATTRIBUTES:
218 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700219 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100220 raise Exception(
221 "Unknown dtype, cannot convert to string: {}".format(dtype)
222 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700223
Luke Hutton57287132023-02-06 14:54:18 +0000224 def constrictBatchSize(self, shape):
225 # Limit the batch size unless an explicit target shape set
226 if self.args.max_batch_size and not self.args.target_shapes:
227 shape[0] = min(shape[0], self.args.max_batch_size)
228 return shape
229
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100230 def makeDimension(self, rng):
231 return rng.randInt(
James Ward30124a82023-02-02 14:56:33 +0000232 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
233 )
234
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100235 def tensorComplianceMetaData(
236 self, op, inputType, argsDict, outputTensor, errorName
237 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000238 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
239 UNSUPPORTED_NON_FP32_INPUT_OPS = (
240 Op.MATMUL,
241 Op.CONV2D,
242 Op.FULLY_CONNECTED,
243 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000244 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000245 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000246 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100247 if (
248 errorName
249 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000250 or (
251 not gtu.dtypeIsSupportedByCompliance(inputType)
252 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
253 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100254 ):
255 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100257
Jeremy Johnson1271c442023-09-05 11:39:26 +0100258 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100259 compliance_tens = {
260 "mode": None,
261 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
262 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
263 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100264 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
265 mode = gtu.ComplianceMode.DOT_PRODUCT
266 compliance_tens["dot_product_info"] = {
267 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100268 "ks": int(argsDict["ksb"])
269 if "ksb" in argsDict
270 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100271 }
evacha019c96eef2024-02-07 11:21:55 +0000272 elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100273 mode = gtu.ComplianceMode.FP_SPECIAL
274 elif "compliance" in op and "ulp" in op["compliance"]:
275 mode = gtu.ComplianceMode.ULP
276 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000277 elif "compliance" in op and "relative" in op["compliance"]:
278 mode = gtu.ComplianceMode.RELATIVE
279 compliance_tens["relative_info"] = {
280 "max": argsDict["max_abs_value"],
281 "scale": op["compliance"]["relative"],
282 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100283 elif op["op"] == Op.REDUCE_PRODUCT:
284 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000285 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000286 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000287 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000288 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
289 compliance_tens["abs_error_info"] = {
290 "lower_bound": op["compliance"]["abs_error_lower_bound"]
291 }
Jerry Ge51bd4f52024-02-20 11:21:19 -0800292 elif op["op"] in (Op.SIN, Op.COS):
293 mode = gtu.ComplianceMode.ABS_ERROR
294 if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
295 compliance_tens["abs_error_info"] = {
296 "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
297 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100298 else:
299 mode = gtu.ComplianceMode.EXACT
300 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
301
302 return compliance_tens
303
304 # Build Op functions
305 # Create the output tensor (calling OutputShaper as needed)
306 # Do final tweaks to attributes (if necessary for errorIf)
307 # Add Op into graph
308 # Return resulting tensor information or BuildInfo
309
310 class BuildInfo:
311 """Enhanced build information containing result tensor and associated compliance dict."""
312
313 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000314 if isinstance(resultTensor, list):
315 assert complianceDict is None or isinstance(complianceDict, list)
316 self.resultTensorList = resultTensor
317 self.complianceDictList = complianceDict
318 else:
319 self.resultTensorList = [resultTensor]
320 if complianceDict is None:
321 self.complianceDictList = None
322 else:
323 self.complianceDictList = [complianceDict]
324
325 def getComplianceInfo(self):
326 if self.complianceDictList is None:
327 return None
328 else:
329 tens_dict = {}
330 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
331 if comp is not None:
332 tens_dict[tens.name] = comp
333
334 if tens_dict:
335 # Have some compliance data, so return the info
336 compliance = {
337 "version": "0.1",
338 "tensors": tens_dict,
339 }
340 else:
341 compliance = None
342 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700343
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000344 def build_unary(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100345 self,
346 rng,
347 op,
348 inputs,
349 args_dict,
350 validator_fcns=None,
351 error_name=None,
352 qinfo=None,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000353 ):
354 assert len(inputs) == 1
355 a = inputs[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100356 result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100357
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000358 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100359
360 # Ensure new output type has correct qinfo
361 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000362 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000363 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100364 TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
365 TosaQuantGen.getZeroPoint(
366 rng, self.args.zeropoint, result_tensor.dtype
367 ),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000368 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100369
370 # Invalidate Input/Output list for error if checks.
371 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000372 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100373 pCount, cCount = op["operands"]
374 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000375 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100376 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000377 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100378
Les Bell729b0352021-11-24 10:28:21 +0000379 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100380 self.ser,
381 validator_fcns,
382 error_name,
383 op=op,
384 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000385 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000386 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000387 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100388 input_list=input_list,
389 output_list=output_list,
390 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000391 ):
392 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100393
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000394 attr = None
395 if op["op"] == Op.NEGATE:
396 attr = ts.TosaSerializerAttribute()
397 attr.NegateAttribute(qinfo[0], qinfo[1])
398
399 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000400
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000401 compliance = self.tensorComplianceMetaData(
402 op, a.dtype, args_dict, result_tensor, error_name
403 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000404 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700405
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000406 def build_binary_broadcast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100407 self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000408 ):
409 assert len(inputs) == 2
410 a, b = inputs
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100411 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100412
413 # Invalidate Input/Output list for error if checks.
414 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000415 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100416 pCount, cCount = op["operands"]
417 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000418 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100419 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000420 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100421
Les Bell729b0352021-11-24 10:28:21 +0000422 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100423 self.ser,
424 validator_fcns,
425 error_name,
426 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 input1=a,
428 input2=b,
429 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000430 output_dtype=result_tensor.dtype,
431 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100432 input_list=input_list,
433 output_list=output_list,
434 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000435 ):
436 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100437
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000438 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000439
Jeremy Johnson9a758382023-11-07 16:27:35 +0000440 compliance = self.tensorComplianceMetaData(
441 op, a.dtype, args_dict, result_tensor, error_name
442 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000443
444 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700445
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000446 def build_arithmetic_right_shift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100447 self,
448 rng,
449 op,
450 inputs,
451 args_dict,
452 validator_fcns=None,
453 error_name=None,
454 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000455 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000456 assert len(inputs) == 2
457 a, b = inputs
458 round = args_dict["round"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100459 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100460
461 # Invalidate Input/Output list for error if checks.
462 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000463 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100464 pCount, cCount = op["operands"]
465 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100467 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000468 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469
Les Bell729b0352021-11-24 10:28:21 +0000470 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471 self.ser,
472 validator_fcns,
473 error_name,
474 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000475 input1=a,
476 input2=b,
477 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000478 output_dtype=result_tensor.dtype,
479 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100480 input_list=input_list,
481 output_list=output_list,
482 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000483 ):
484 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800485
486 attr = ts.TosaSerializerAttribute()
487 attr.ArithmeticRightShiftAttribute(round)
488
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000489 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000490
491 compliance = self.tensorComplianceMetaData(
492 op, a.dtype, args_dict, result_tensor, error_name
493 )
494
495 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800496
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100497 def build_mul(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100498 self,
499 rng,
500 op,
501 inputs,
502 args_dict,
503 validator_fcns=None,
504 error_name=None,
505 qinfo=None,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100506 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000507 # Note that mul is binary operator but it has a shift value tensor
508 assert len(inputs) == 3
509 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100510
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100511 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700512
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100513 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100514 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100515 result_tensor.setDtype(DType.INT32)
516
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100517 if error_name == ErrorIf.WrongOutputType:
518 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100519 outputDType = rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100520 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100521
522 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000523 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100524 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100525 pCount, cCount = op["operands"]
526 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000527 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100528 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000529 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100530
Les Bell729b0352021-11-24 10:28:21 +0000531 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100532 self.ser,
533 validator_fcns,
534 error_name,
535 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000536 input1=a,
537 input2=b,
538 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100539 output_dtype=result_tensor.dtype,
540 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100541 input_list=input_list,
542 output_list=output_list,
543 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000544 ):
545 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700546
Jeremy Johnson0a042992024-02-28 13:20:05 +0000547 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100548
549 compliance = self.tensorComplianceMetaData(
550 op, a.dtype, args_dict, result_tensor, error_name
551 )
552
553 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700554
Jeremy Johnson587cc842024-02-08 11:45:44 +0000555 def build_table(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100556 self,
557 rng,
558 op,
559 inputs,
560 args_dict,
561 validator_fcns=None,
562 error_name=None,
563 qinfo=None,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000564 ):
565 assert len(inputs) == 1
566 a = inputs[0]
567 table = args_dict["table"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100568 result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
Kevin Chengfe392ce2021-10-18 21:51:55 +0000570 attr = ts.TosaSerializerAttribute()
571 attr.TableAttribute(table)
572
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100573 # Invalidate Input/Output list for error if checks.
574 input_list = [a.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000575 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100576 pCount, cCount = op["operands"]
577 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000578 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100579 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000580 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581
Les Bell729b0352021-11-24 10:28:21 +0000582 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100583 self.ser,
584 validator_fcns,
585 error_name,
586 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000587 input_shape=a.shape,
588 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000589 output_dtype=result_tensor.dtype,
590 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100591 input_list=input_list,
592 output_list=output_list,
593 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000594 ):
595 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100596
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000597 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700598
Jeremy Johnson587cc842024-02-08 11:45:44 +0000599 compliance = self.tensorComplianceMetaData(
600 op, a.dtype, args_dict, result_tensor, error_name
601 )
602
603 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700604
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000605 def build_select(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100606 self,
607 rng,
608 op,
609 inputs,
610 args_dict,
611 validator_fcns=None,
612 error_name=None,
613 qinfo=None,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000614 ):
615 assert len(inputs) == 3
616 cond, a, b = inputs
617
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100618 result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100619
620 # Invalidate Input/Output list for error if checks.
621 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000622 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100623 pCount, cCount = op["operands"]
624 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000625 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100626 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000627 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100628
Les Bell729b0352021-11-24 10:28:21 +0000629 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100630 self.ser,
631 validator_fcns,
632 error_name,
633 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000634 input1=cond,
635 input2=a,
636 input3=b,
637 input_shape=a.shape,
638 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000639 output_dtype=result_tensor.dtype,
640 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100641 input_list=input_list,
642 output_list=output_list,
643 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000644 ):
645 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100646
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000647 self.ser.addOperator(
648 op["op"],
649 input_list,
650 output_list,
651 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000652 compliance = self.tensorComplianceMetaData(
653 op, a.dtype, args_dict, result_tensor, error_name
654 )
655
656 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700657
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a binary comparison operator test.

    Args:
        rng: random generator, consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly two input tensors (lhs, rhs).
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    # Invalidate Input/Output list for error if checks.
    n_params, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [out.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=n_params + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], ins, outs)

    return TosaTestGen.BuildInfo(
        out,
        self.tensorComplianceMetaData(op, lhs.dtype, args_dict, out, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700710
def build_argmax(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an ARGMAX operator test.

    Args:
        rng: random generator, consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly one input tensor.
        args_dict: test arguments; "axis" is read here, and the whole dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions (defaults to None, as in
            the other builders).
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700754
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        rng: random generator, consumed by the output shaper, the zero-point
            generation (WrongInputType case) and the ERROR_IF list
            invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly one input tensor.
        args_dict: test arguments; "stride", "pad", "kernel" and the optional
            "acc_type" are read, and the whole dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to PoolAttribute as qinfo[0] and
            qinfo[1]; None falls back to [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    # Named `ifm` rather than `input` to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
                TosaQuantGen.getZeroPoint(
                    rng, self.args.zeropoint, result_tensor.dtype
                ),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100832
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D convolution operator test.

    Args:
        rng: random generator, consumed by the output shaper, the zero-point
            generation (WrongInputType case) and the ERROR_IF list
            invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly three tensors (input, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (4 values) and
            "dilation" are read, and the whole dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; None falls back to [0, 0] instead of failing with a
            TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700920
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 3D convolution operator test.

    Args:
        rng: random generator, consumed by the output shaper, the zero-point
            generation (WrongInputType case) and the ERROR_IF list
            invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly three tensors (input, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (6 values) and
            "dilation" are read, and the whole dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; None falls back to [0, 0] instead of failing with a
            TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001008
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D transpose convolution operator test.

    Args:
        rng: random generator, consumed by the output shaper, the zero-point
            generation (WrongInputType case) and the ERROR_IF list
            invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly three tensors (input, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (4 values) and
            "out_shape" are read, and the whole dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to TransposeConvAttribute as
            qinfo[0] and qinfo[1]; None falls back to [0, 0] instead of
            failing with a TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001087
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D depthwise convolution operator test.

    Args:
        rng: random generator, consumed by the output shaper, the zero-point
            generation (WrongInputType case) and the ERROR_IF list
            invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly three tensors (input, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" and "dilation"
            are read, and the whole dict is forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; None falls back to [0, 0] instead of failing with a
            TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001174
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a fully connected operator test.

    Args:
        rng: random generator, consumed by the output shaper and the ERROR_IF
            list invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly three tensors (input, weight, bias).
        args_dict: test arguments; "acc_type" is read, and the whole dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to FullyConnectedAttribute as
            qinfo[0] and qinfo[1]; None falls back to [0, 0] instead of
            failing with a TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001231
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MATMUL operator test.

    Args:
        rng: random generator, consumed by the output shaper and the ERROR_IF
            list invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly two input tensors (a, b).
        args_dict: test arguments; "acc_type" is read, and the whole dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of zero points passed to MatMulAttribute as qinfo[0] and
            qinfo[1]; None falls back to [0, 0] instead of failing with a
            TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; the attribute needs two values.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001288
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator test along a single axis.

    Args:
        rng: random generator, consumed by the output shaper and the ERROR_IF
            list invalidation.
        op: operator description dict; "op" and "operands" are read here.
        inputs: exactly one input tensor.
        args_dict: test arguments; "axis" is read. For a non-error
            REDUCE_PRODUCT test this dict is mutated in place: "n" is set to
            the size of the reduced dimension (needed for compliance).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    reduce_axis = args_dict["axis"]
    out = OutputShaper.reduceOp(self.ser, rng, src, reduce_axis, error_name)

    # Invalidate Input/Output list for error if checks.
    n_params, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=reduce_axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=n_params + n_consts,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reduce_axis)
    self.ser.addOperator(op["op"], ins, outs, axis_attr)

    # Number of products - needed for compliance
    needs_product_count = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if needs_product_count:
        args_dict["n"] = src.shape[reduce_axis]

    return TosaTestGen.BuildInfo(
        out,
        self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001337
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes the lower bound and the larger one the upper bound. For the
    ``ErrorIf.MaxSmallerMin`` negative test, values are redrawn until they
    differ, and the bounds are then assigned the opposite way round.

    The bounds are packed as little-endian 4-byte values for the
    ClampAttribute:
      * float32 for BF16/FP16/FP32 inputs (FP16 values are cast first);
      * int32 for INT8/INT16 inputs;
      * zero for any other dtype.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    output = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    pair = [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]

    if error_name != ErrorIf.MaxSmallerMin:
        lower_bound, upper_bound = min(pair), max(pair)
    else:
        # The two values must differ, otherwise swapping them cannot
        # produce max < min.
        while pair[0] == pair[1]:
            pair = [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]
        lower_bound, upper_bound = max(pair), min(pair)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=upper_bound,
        min_val=lower_bound,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            lower_bound = lower_bound.astype(np.float32)
            upper_bound = upper_bound.astype(np.float32)
        pack_fmt = "<f"
    elif src.dtype in (DType.INT8, DType.INT16):
        pack_fmt = "<i"
    else:
        # to avoid internal error for incorrect input types
        pack_fmt, lower_bound, upper_bound = "<i", 0, 0

    attr.ClampAttribute(
        self.ser.builder,
        struct.pack(pack_fmt, lower_bound),
        struct.pack(pack_fmt, upper_bound),
    )

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001417
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a single-input operator that takes no attribute.

    The output tensor shape comes from ``OutputShaper.unaryOp``.

    Returns None if ERROR_IF validation rejects the test. Otherwise returns a
    ``TosaTestGen.BuildInfo`` with the output tensor and compliance metadata.
    """
    assert len(inputs) == 1
    src = inputs[0]

    output = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001465
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT joins its inputs along ``args_dict["axis"]`` and carries an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and is serialized without
    an attribute.

    Returns None if ERROR_IF validation rejects the test. Otherwise returns a
    ``TosaTestGen.BuildInfo`` with the output tensor and compliance metadata
    (computed with the dtype of the first input).
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance rather than comparing types with == (E721).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001531
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator.

    ``inputs`` holds the data tensor followed by the padding tensor. The
    output shape is derived from ``args_dict["pad"]``.

    The pad constant is packed as 4 little-endian bytes: float32 from
    ``args_dict["pad_const_fp"]`` for floating-point inputs, otherwise int32
    from ``args_dict["pad_const_int"]``.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, pad_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    output = OutputShaper.padOp(self.ser, rng, src, padding, error_name)

    pad_const_bytes = (
        struct.pack("<f", pad_const_float)
        if gtu.dtypeIsFloat(src.dtype)
        else struct.pack("<i", pad_const_int)
    )

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, pad_const_bytes)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, pad_tensor.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001595
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator that queries ``args_dict["axis"]`` of one input.

    Returns None if ERROR_IF validation rejects the test. Otherwise returns a
    ``TosaTestGen.BuildInfo`` whose compliance entry is None; no compliance
    metadata is generated here.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    output = OutputShaper.dimOp(self.ser, rng, src, dim_axis, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=output.shape,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(output, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001642
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a RESHAPE operator.

    ``inputs`` holds the data tensor followed by the tensor carrying the
    target shape. The output shape is computed from
    ``args_dict["new_shape"]``. The operator is serialized without an
    attribute.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs[0], inputs[1]
    target_shape = args_dict["new_shape"]
    output = OutputShaper.reshapeOp(self.ser, rng, src, target_shape, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, shape_tensor.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001691
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a single-input operator parameterized by ``args_dict["axis"]``.

    The output shape comes from ``OutputShaper.unaryOp``, so it matches the
    input.

    Returns None if ERROR_IF validation rejects the test. Otherwise returns a
    ``TosaTestGen.BuildInfo`` whose compliance entry is None.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    output = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(output, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001738
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    The permutation is stored in a TransposeAttribute.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    perms = args_dict["perms"]

    output = OutputShaper.transposeOp(self.ser, rng, src, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001792
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a SLICE operator.

    ``inputs`` holds the data tensor, the start tensor and the size tensor.
    The output shape and the ERROR_IF checks use the constant values in
    ``args_dict["start"]`` and ``args_dict["size"]``. The operator is
    serialized without an attribute.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    output = OutputShaper.sliceOp(
        self.ser, rng, src, start_vals, size_vals, error_name
    )

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [output.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001847
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TILE operator.

    ``inputs`` holds the data tensor followed by the multiples tensor. The
    output shape is computed from ``args_dict["multiples"]``. The operator is
    serialized without an attribute.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    multiples_vals = args_dict["multiples"]
    output = OutputShaper.tileOp(self.ser, rng, src, multiples_vals, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, multiples_tensor.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=output.shape,
        input_dtype=src.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001899
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a GATHER operator from a values tensor and an indices tensor.

    The operator is serialized without an attribute. Compliance metadata uses
    the dtype of the values tensor.

    Returns None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    output = OutputShaper.gatherOp(self.ser, rng, values, indices, error_name)

    # Negative (ERROR_IF) tests may have their tensor name lists corrupted here.
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [values.name, indices.name], [output.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=output.shape,
        input_dtype=values.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001949
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    ``inputs`` holds three tensors: values_in, indices and input. The
    output tensor is shaped by OutputShaper.scatterOp.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs

    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001998
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds four tensors: the input, then the scale, offset and
    border tensors. Because scale/offset/border are passed as input
    tensors, the serialized ResizeAttribute carries empty lists for them
    and only records the mode.

    Values read from ``args_dict``: mode, scale, offset, border,
    output_dtype. They are used to shape the output and for ERROR_IF
    validation.

    ``validator_fcns`` defaults to None, consistent with the other
    build_* methods.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.
    """
    assert len(inputs) == 4
    input_tensor = inputs[0]
    scale_input = inputs[1]
    offset_input = inputs[2]
    border_input = inputs[3]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # write empty scale/offset/border into ResizeAttribute
    # (they are supplied as the scale/offset/border input tensors instead)
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002078
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST test by exposing the single input tensor as an output.

    No operator is added to the graph. The supplied tensor is registered
    directly as a serializer output, and the same tensor is used as the
    result for compliance metadata.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]

    self.ser.addOutputTensor(const_tensor)

    compliance_info = self.tensorComplianceMetaData(
        op, const_tensor.dtype, args_dict, const_tensor, error_name
    )
    return TosaTestGen.BuildInfo(const_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002098
2099 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST (type conversion) operator test.

    ``args_dict["out_type"]`` selects the destination dtype. The output
    tensor is shaped by OutputShaper.typeConversionOp.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.
    """
    assert len(inputs) == 1
    src_tensor = inputs[0]
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, src_tensor, target_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src_tensor.name], [result_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src_tensor.shape,
        output_shape=result_tensor.shape,
        input_dtype=src_tensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance_info = self.tensorComplianceMetaData(
        op, src_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002150
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` holds three tensors: the value tensor, the multiplier tensor
    and the shift tensor. Values read from ``args_dict``: output_dtype,
    scale (the scale32 flag), double_round, per_channel, shift (per-channel
    shift values) and multiplier.

    Random zero points are picked for INT8/UINT8/UINT16 inputs and outputs,
    and for the zero-point ERROR_IF tests. For valid (non-error) scale32
    tests, the placeholder data file is rewritten so that
    (value - input_zp) lies within the apply_scale_32 range implied by each
    channel's shift.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Number of shift/multiplier channels: the last dimension when
    # per_channel, otherwise a single shared value
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            # Force a non-zero zero point so the ERROR_IF condition triggers
            input_zp = input_zp + rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            # Force a non-zero zero point so the ERROR_IF condition triggers
            output_zp = output_zp + rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds: -(1 << (shift - 1)) <= value <= (1 << (shift - 1)) - 1
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype by comparing the
        # actual extremes (val_adj.all() is a bool and cannot detect this).
        # min()/max() raise on empty arrays, hence the size guard.
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002317
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Add and return a constant condition tensor holding ``cond``.

    By default this is a rank-0 BOOL constant. For the
    CondIfCondNotMatchingBool ERROR_IF test a wrong dtype is substituted.
    For CondIfCondShapeNotSizeOne the shape is randomly [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2334
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Of the two supplied tensors, only the first one's shape is used. The
    condition comes from ``args_dict["condition"]``. Each block outputs an
    INT32 constant of that shape filled with random values in [0, 256).
    For the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the
    matching block instead outputs a constant of a perturbed shape.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.
    """
    assert len(inputs) == 2
    shape_source, _unused_else = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = shape_source.shape
    dtype = DType.INT32

    # Create an incorrect output shape for error_if tests
    output_mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in output_mismatch_errors:
        incorrect_shape = deepcopy(shape_source.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] = dim + rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] = dim + rng.choice([1, 2, 4])
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    # The result tensor matches the (correct) block output shape
    result_tensor = self.ser.addOutput(out_shape, dtype)

    # Name the then/else blocks in the COND_IF attribute
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Each block contains a single constant that is its output
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, block_mismatch_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_mismatch_error:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not passed_checks:
        return None

    compliance_info = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002420
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each apply a binary op to a, b.

    For FP32/BF16/FP16/INT32 the THEN block uses the "add" op and the ELSE
    block the "sub" op. For INT8/INT16 they use "logical_right_shift" and
    "logical_left_shift" respectively. The condition comes from
    ``args_dict["condition"]``. Compliance metadata is generated for the
    block op the condition selects.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the matching block gets an input or output tensor of a perturbed shape.

    Returns None when ERROR_IF validation rejects the test, otherwise a
    BuildInfo holding the result tensor and its compliance metadata.

    Raises:
        AssertionError: if a.dtype has no then/else op pairing. Raised
            explicitly so the check survives ``python -O``.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension to create a block input/output mismatch
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise (not `assert False`) so that under -O this does not
        # fall through to an UnboundLocalError on then_op/else_op.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2525
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that accumulates an input while counting down.

    The loop carries three tensors (iter, a, acc).  The COND graph outputs
    ``iter > 0`` (GREATER against a zero constant); the BODY graph computes
    ``acc + a`` (ADD) and ``iter - 1`` (SUB).  ERROR_IF variants swap in
    copies of the loop tensors with randomly shifted shapes, or change the
    dtype/shape of the COND graph output.

    Args:
        rng: random generator used for every ERROR_IF perturbation.
        op: operator entry from TOSA_OP_LIST.
        inputs: single-element list holding the tensor added each iteration.
        args_dict: test arguments; ``"iterations"`` is the initial iterator
            value.
        validator_fcns: ERROR_IF validator functions for this op.
        error_name: ERROR_IF being generated, or None for a positive test.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the accumulator output, or None when the
        ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    shape_offsets = [-3, -2, 2, 3]

    def perturbed(tens):
        # Deep copy of tens with each dimension shifted by a random offset
        bad_tens = deepcopy(tens)
        for idx, dim in enumerate(bad_tens.shape):
            bad_tens.shape[idx] = dim + rng.choice(shape_offsets)
        return bad_tens

    loop_count = args_dict["iterations"]
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(loop_count)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Zero-initialised accumulator placeholder
    # NOTE(review): the initial values are int32 whatever a.dtype is
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    # Mis-shaped copies of the loop-carried tensors for graph mismatch errors
    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        bad_iter = perturbed(iter_tens)
        if not bad_iter.shape:
            # A scalar has no dimension to shift, so give it a bogus one
            bad_iter.shape.append(rng.choice(shape_offsets))
        bad_acc = perturbed(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_iter, a, bad_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)
    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, acc + a)
    # Local intermediate tensors are declared here for the outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_iter, a, bad_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = bad_iter, bad_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not valid:
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002663
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D test from two input tensors.

    Args:
        rng: random generator for output shaping and ERROR_IF list corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: the two input tensors ``(val1, val2)``.
        args_dict: test arguments; ``"inverse"`` is passed to the FFT
            attribute.
        validator_fcns: ERROR_IF validator functions for this op.
        error_name: ERROR_IF being generated, or None for a positive test.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo holding every result tensor produced by
        OutputShaper.fft2dOp (one compliance entry per result), or None when
        the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    val1, val2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    output_names = [tens.name for tens in results]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    # ERROR_IF tests may get a corrupted input/output name list
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [val1.name, val2.name], output_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val1.dtype, args_dict, tens, error_name)
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002728
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D test from a single input tensor.

    Args:
        rng: random generator for output shaping and ERROR_IF list corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions for this op.
        error_name: ERROR_IF being generated, or None for a positive test.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo holding every result tensor produced by
        OutputShaper.rfft2dOp (one compliance entry per result), or None
        when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    val = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    output_names = [tens.name for tens in results]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    # ERROR_IF tests may get a corrupted input/output name list
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [val.name], output_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val.dtype, args_dict, tens, error_name)
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002786
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a test for a binary shape operator with inputs ``a`` and ``b``.

    The result tensor is created by OutputShaper.addShapeOp.

    Args:
        rng: random generator for output shaping and ERROR_IF list corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: the two input tensors ``(a, b)``.
        args_dict: test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions for this op.
        error_name: ERROR_IF being generated, or None for a positive test.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # The first argument is the random number generator (as in the
    # build_fft2d / build_rfft2d calls); this previously passed ``self``.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2839
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Resolve the shape, rank and dtype filters used to enumerate tests.

    Args:
        op: operator entry from TOSA_OP_LIST; its ``"rank"`` (min, max)
            range and ``"types"`` list bound what can be generated.
        shapeFilter: requested shapes.  None, an empty list, or a list
            containing only None entries all mean "no specific shapes".
        rankFilter: requested ranks, or None to use the default range.
        dtypeFilter: requested dtypes, or None to keep all op types.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries, or
        None for a negative test without a validator (or any other testType).
    """
    # A filter holding only None entries (such as genOpTestList's [None]
    # default) requests no shapes. Without this, the truthiness checks
    # below treat it as "shapes specified" and discard the rank filter.
    if shapeFilter and all(shape is None for shape in shapeFilter):
        shapeFilter = None

    # Create a default testing rank range
    if testType == "positive":
        # 0-3 inclusive to keep test sizes reasonably small.
        default_test_rank_range = range(0, 4)
    else:
        # Some errors do not work with rank 0, use 1-3
        default_test_rank_range = range(1, 4)

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]

    if shapeFilter:
        # Specified shapes - ignore rank filter and default to op ranks below
        rankFilter = None
        ranksToCheck = []
    elif rankFilter is None:
        # No set rank filter so ensure default behaviour is bounded
        ranksToCheck = default_test_rank_range
    else:
        ranksToCheck = rankFilter

    # Ensure rank values are allowed by operator
    cleanRankFilter = [rank for rank in ranksToCheck if rmin <= rank <= rmax]

    if shapeFilter or (len(cleanRankFilter) == 0 and rankFilter is None):
        # Shapes specified or default test ranks didn't meet
        # op requirements - so just use op ranks
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep the requested op dtypes; a list entry is matched on its
        # first element
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if not shapeFilter:
        shapeFilter = [None]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required by the ERROR_IF validator
    if error_arguments["rank"] is not None:
        rankFilter = error_arguments["rank"]
    else:
        rankFilter = cleanRankFilter

    if error_arguments["dtype"] is not None:
        dtypeFilter = error_arguments["dtype"]
    else:
        dtypeFilter = cleanDtypeFilter

    if error_arguments["shape"] is not None:
        shapeFilter = error_arguments["shape"]
    else:
        # Reduce number of shapes to keep test numbers small
        shapeFilter = shapeFilter[:2]

    return {
        "shapeFilter": shapeFilter,
        "rankFilter": rankFilter,
        "dtypeFilter": dtypeFilter,
    }
2930
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate every test to generate for operator ``opName``.

    Tests are produced for each filtered (rank, dtype, shape) combination
    returned by ``create_filter_lists``.  Negative runs repeat this once per
    ERROR_IF validator of the op.  Each combination yields one test per
    entry of the op's argument generator output.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: requested shapes (passed to ``create_filter_lists``).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".

    Returns:
        List of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is unknown, or if a negative run (outside
            level8k mode) fails to create tests for every ERROR_IF type.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    if not self.args.stable_rng:
        # Start each op with a freshly initialised global generator
        self.resetGlobalRNG()

    _build_fcn, tgen_fcn, _tvgen_fcn, agen_fcn = op["build_fcn"]

    # Negative runs iterate over the op's ERROR_IF validators and count how
    # many error types produced tests; positive runs make a single pass.
    negative_run = testType == "negative" and "error_if_validators" in op
    if negative_run:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        error_if_validators = [None]
        num_error_types_created = None

    testList = []
    for validator in error_if_validators:
        error_name = (
            None if validator is None else validator(check=False, op=op)["error_name"]
        )

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanShapeFilter = filterDict["shapeFilter"]
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                # Explicit shapes are only used for the matching rank
                for shape in (
                    s for s in cleanShapeFilter if s is None or len(s) == r
                ):
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)
                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument list entries are (name suffix, build args) pairs
                    if not agen_fcn:
                        argList = [("", [])]
                    else:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng
                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )

                    for argStr, args in argList:
                        # e.g. add_1x2x3_i32, or add_ERRORIF_<error>_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

        if error_name is None:
            continue
        # Tests for this error type, if any were made, end the list
        if testList and testList[-1][3] == error_name:
            # Successfully created at least one ERROR_IF test
            num_error_types_created += 1
        elif self.args.level8k:
            logger.info(f"Missing {error_name} tests due to level8k mode")
        else:
            logger.error(f"ERROR: Failed to create any {error_name} tests")
            logger.debug(
                "Last test created: {}".format(testList[-1] if testList else None)
            )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]

            def is_invalid(test):
                # Every validator is called, even after one has matched
                verdicts = [
                    validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    )
                    for validator_fcn in invalid_test_validators
                ]
                return any(verdicts)

            testList = [test for test in testList if not is_invalid(test)]
    elif num_error_types_created is not None and not self.args.level8k:
        remaining_error_types = len(error_if_validators) - num_error_types_created
        if remaining_error_types:
            raise Exception(
                f"Failed to create {remaining_error_types} error types for {opName}"
            )

    return testList
3082
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build test ``testStr`` for ``opName`` and serialize it.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name; also seeds the generator in stable-RNG mode.
        dtype_or_dtypeList: one dtype shared by every operand, or an
            explicit per-operand list.
        error_name: ERROR_IF being generated, or None.
        shapeList: operand shapes.
        argsDict: argument dictionary for the tensor-value and build
            functions; must contain ``"dg_type"`` and may contain ``"tags"``.

    Returns:
        True if the test was built and serialized; False if the build
        function returned a falsy result (logged as an invalid ERROR_IF test).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    logger.info(f"Creating {testStr}")

    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # CONCAT ops size their dtype list from shapeList and skip the
    # operand-count checks
    variadic = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if variadic else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not variadic:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Stable mode derives a generator from the test name
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    qinfo = (
        None
        if qgen is None
        else qgen(build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name)
    )

    # Only the dictionary-based tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    # Extra meta data for the desc.json
    tensMeta = {"data_gen": tvgInfo.dataGenDict} if tvgInfo.dataGenDict else {}
    tags = argsDict.get("tags", None)

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The build function rejected the test
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    if isinstance(result, TosaTestGen.BuildInfo):
        # Add the compliance meta data (if any)
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta, tags)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003189
def createDynamicOpLists(self):
    """Replace each "<name>_TEMPLATE" op with one concrete op per kernel.

    Kernel sources, in order of precedence:
    - level 8K mode (``args.level8k``): two fixed kernels that each use
      TOSA_8K_LEVEL_MAX_KERNEL in one dimension;
    - otherwise, the ``args.conv_kernels`` entries of the right rank;
    - failing that, the template's own ``"filter"`` list.

    Each generated op is a shallow copy of the template, keyed
    ``<name>_<k0>x<k1>...``. It has ``"filter"`` set to the kernel,
    ``"template"`` set to False and ``"real_name"`` set to ``<name>``. The
    template entry is then deleted.
    """
    suffix = "_TEMPLATE"
    templateKeys = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]

    bigK = self.TOSA_8K_LEVEL_MAX_KERNEL

    for opName in templateKeys:
        assert opName.endswith(suffix), "Found incorrect template"
        realName = opName[: len(opName) - len(suffix)]
        template = self.TOSA_OP_LIST[opName]
        k_rank = 3 if realName == "conv3d" else 2

        if self.args.level8k:
            kernels = (
                [[1, bigK, 1], [2, 2, bigK]] if k_rank == 3 else [[1, bigK], [bigK, 2]]
            )
        else:
            requested = self.args.conv_kernels
            kernels = [k for k in requested if len(k) == k_rank]
            if requested and not kernels:
                logger.debug(
                    f"{realName} op using defaults as no rank {k_rank} kernels found in {self.args.conv_kernels}"
                )
            if not kernels:
                # Fallback to use the defined template kernels
                kernels = template["filter"]

        for kernel in kernels:
            kernelStr = "x".join(str(dim) for dim in kernel)
            kernelOp = dict(template)
            kernelOp.update(filter=kernel, template=False, real_name=realName)
            self.TOSA_OP_LIST[f"{realName}_{kernelStr}"] = kernelOp

        # Delete the template after having created the dynamic ops
        del self.TOSA_OP_LIST[opName]
Eric Kunzee5e26762020-10-13 16:11:07 -07003239
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide:
    - an ``"operands"`` pair;
    - a 4-element ``"build_fcn"``;
    - a ``"types"`` field;
    - an ``"op"`` field.

    A missing or malformed required field raises ``Exception``. An entry
    without ``"rank"`` gets ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Unpacking checks both that the field exists and its arity
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _argsgen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for field, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            if field not in entry:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003281
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the Op value for the operator (the dict key is the test op name)
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, defaults to DEFAULT_RANK_RANGE,
#     i.e. (0, gtu.MAX_TENSOR_RANK)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional fields used by entries include 'qgen' (a TosaQuantGen function),
# 'invalid_test_validators', 'error_if_validators', 'data_gen',
# 'compliance', 'filter', 'broadcastable_bias' and 'template' (marks a
# templated operator that is filled in by createDynamicOpLists).
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# floating-types, integer types (excluding INT4) and BOOL
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus floating-types (no INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range (min, max) applied to ops that don't specify 'rank'
DEFAULT_RANK_RANGE = (0, gtu.MAX_TENSOR_RANK)

# Kernel sizes used as the 'filter' list of the templated conv ops
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

# Per-dtype tuples of gtu.DataGenType, used as an op's 'data_gen' entry
PSEUDO_RANDOM_DATAGEN = {
    DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,),
    DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,),
}
DOT_PRODUCT_DATAGEN = {
    DType.FP16: (gtu.DataGenType.DOT_PRODUCT,),
    DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
}
# Only FP16 is listed here (FULL_RANGE is not offered for FP32)
EW_UNARY_DATAGEN = {
    DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE)
}
3356
Eric Kunzee5e26762020-10-13 16:11:07 -07003357 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003358 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003359 "argmax": {
3360 "op": Op.ARGMAX,
3361 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00003362 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003363 "build_fcn": (
3364 build_argmax,
3365 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003366 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 TosaArgGen.agAxis,
3368 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003369 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003370 "error_if_validators": (
3371 TosaErrorValidator.evAxisSmallerZero,
3372 TosaErrorValidator.evAxisLargerRank,
3373 TosaErrorValidator.evArgmaxOutputRankMismatch,
3374 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3375 TosaErrorValidator.evWrongRank,
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003381 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003382 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "avg_pool2d": {
3384 "op": Op.AVG_POOL2D,
3385 "operands": (1, 0),
3386 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 "build_fcn": (
3388 build_pool2d,
3389 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003390 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003391 TosaArgGen.agPooling,
3392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003394 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003395 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003396 "error_if_validators": (
3397 TosaErrorValidator.evKernelSmallerOne,
3398 TosaErrorValidator.evStrideSmallerOne,
3399 TosaErrorValidator.evPadSmallerZero,
3400 TosaErrorValidator.evWrongRank,
3401 TosaErrorValidator.evWrongInputType,
3402 TosaErrorValidator.evWrongOutputType,
3403 TosaErrorValidator.evWrongInputList,
3404 TosaErrorValidator.evWrongOutputList,
3405 TosaErrorValidator.evInputZeroPointNotZero,
3406 TosaErrorValidator.evOutputZeroPointNotZero,
3407 TosaErrorValidator.evPadLargerEqualKernel,
3408 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003409 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003410 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003411 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003412 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003414 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003415 "conv2d_TEMPLATE": {
3416 "op": Op.CONV2D,
3417 "operands": (1, 2),
3418 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 "build_fcn": (
3420 build_conv2d,
3421 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003422 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003423 TosaArgGen.agConv,
3424 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003425 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003426 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003427 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3428 "error_if_validators": (
3429 TosaErrorValidator.evWrongInputType,
3430 TosaErrorValidator.evWrongOutputType,
3431 TosaErrorValidator.evWrongInputList,
3432 TosaErrorValidator.evWrongOutputList,
3433 TosaErrorValidator.evInputZeroPointNotZero,
3434 TosaErrorValidator.evWeightZeroPointNotZero,
3435 TosaErrorValidator.evPadSmallerZero,
3436 TosaErrorValidator.evStrideSmallerOne,
3437 TosaErrorValidator.evDilationSmallerOne,
3438 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003439 TosaErrorValidator.evConvOutputShapeMismatch,
3440 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003441 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003442 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003443 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003444 "broadcastable_bias": True,
3445 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003446 "template": True,
3447 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003448 # Templated operator. Filled in by createDynamicOpLists
3449 "conv3d_TEMPLATE": {
3450 "op": Op.CONV3D,
3451 "operands": (1, 2),
3452 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003453 "build_fcn": (
3454 build_conv3d,
3455 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003456 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 TosaArgGen.agConv,
3458 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003459 "qgen": TosaQuantGen.qgConv,
3460 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003461 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3462 "error_if_validators": (
3463 TosaErrorValidator.evWrongInputType,
3464 TosaErrorValidator.evWrongOutputType,
3465 TosaErrorValidator.evWrongInputList,
3466 TosaErrorValidator.evWrongOutputList,
3467 TosaErrorValidator.evInputZeroPointNotZero,
3468 TosaErrorValidator.evWeightZeroPointNotZero,
3469 TosaErrorValidator.evPadSmallerZero,
3470 TosaErrorValidator.evStrideSmallerOne,
3471 TosaErrorValidator.evDilationSmallerOne,
3472 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003473 TosaErrorValidator.evConvOutputShapeMismatch,
3474 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003475 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003476 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003477 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003478 "filter": KERNELS_3D,
Kevin Cheng1533b852021-09-01 12:51:58 -07003479 "template": True,
3480 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003481 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003482 "depthwise_conv2d_TEMPLATE": {
3483 "op": Op.DEPTHWISE_CONV2D,
3484 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 "rank": (4, 4),
3486 "build_fcn": (
3487 build_depthwise_conv2d,
3488 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003489 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003490 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003491 ),
3492 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003493 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003494 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3495 "error_if_validators": (
3496 TosaErrorValidator.evWrongInputType,
3497 TosaErrorValidator.evWrongOutputType,
3498 TosaErrorValidator.evWrongInputList,
3499 TosaErrorValidator.evWrongOutputList,
3500 TosaErrorValidator.evInputZeroPointNotZero,
3501 TosaErrorValidator.evWeightZeroPointNotZero,
3502 TosaErrorValidator.evPadSmallerZero,
3503 TosaErrorValidator.evStrideSmallerOne,
3504 TosaErrorValidator.evDilationSmallerOne,
3505 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003506 TosaErrorValidator.evConvOutputShapeMismatch,
3507 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003508 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003509 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003510 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003511 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003512 "template": True,
3513 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 "fully_connected": {
3515 "op": Op.FULLY_CONNECTED,
3516 "operands": (1, 2),
3517 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003518 "build_fcn": (
3519 build_fully_connected,
3520 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003521 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003522 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003525 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evInputZeroPointNotZero,
3528 TosaErrorValidator.evWeightZeroPointNotZero,
3529 TosaErrorValidator.evWrongRank,
3530 TosaErrorValidator.evWrongInputType,
3531 TosaErrorValidator.evWrongOutputType,
3532 TosaErrorValidator.evWrongInputList,
3533 TosaErrorValidator.evWrongOutputList,
3534 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003535 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "matmul": {
3538 "op": Op.MATMUL,
3539 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003540 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 "build_fcn": (
3542 build_matmul,
3543 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003544 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003545 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003548 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003549 "error_if_validators": (
3550 TosaErrorValidator.evInputZeroPointNotZero,
3551 TosaErrorValidator.evWrongRank,
3552 TosaErrorValidator.evWrongInputType,
3553 TosaErrorValidator.evWrongOutputType,
3554 TosaErrorValidator.evWrongInputList,
3555 TosaErrorValidator.evWrongOutputList,
3556 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003557 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "max_pool2d": {
3560 "op": Op.MAX_POOL2D,
3561 "operands": (1, 0),
3562 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003564 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003565 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003566 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 TosaArgGen.agPooling,
3568 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003569 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003570 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003571 "error_if_validators": (
3572 TosaErrorValidator.evKernelSmallerOne,
3573 TosaErrorValidator.evStrideSmallerOne,
3574 TosaErrorValidator.evPadSmallerZero,
3575 TosaErrorValidator.evWrongRank,
3576 TosaErrorValidator.evWrongInputType,
3577 TosaErrorValidator.evWrongOutputType,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 TosaErrorValidator.evPadLargerEqualKernel,
3581 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003582 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003584 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003586 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 "transpose_conv2d_TEMPLATE": {
3588 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003589 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003590 "rank": (4, 4),
3591 "build_fcn": (
3592 build_transpose_conv2d,
3593 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003594 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 TosaArgGen.agTransposeConv2D,
3596 ),
3597 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003598 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003599 "invalid_test_validators": (
3600 TosaInvalidValidator.ivHeightWidthInvalid,
3601 TosaInvalidValidator.ivNonPositiveOutputShape,
3602 ),
3603 "error_if_validators": (
3604 TosaErrorValidator.evWrongInputType,
3605 TosaErrorValidator.evWrongOutputType,
3606 TosaErrorValidator.evWrongInputList,
3607 TosaErrorValidator.evWrongOutputList,
3608 TosaErrorValidator.evInputZeroPointNotZero,
3609 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003610 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003611 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003612 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003613 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003614 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003615 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003616 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003617 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 "template": True,
3619 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003620 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003621 "clamp": {
3622 "op": Op.CLAMP,
3623 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 "build_fcn": (
3625 build_clamp,
3626 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003627 TosaTensorValuesGen.tvgLazyGenDefault,
3628 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003629 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003630 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003631 "error_if_validators": (
3632 TosaErrorValidator.evMaxSmallerMin,
3633 TosaErrorValidator.evWrongInputType,
3634 TosaErrorValidator.evWrongOutputType,
3635 TosaErrorValidator.evWrongInputList,
3636 TosaErrorValidator.evWrongOutputList,
3637 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003638 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003639 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003640 "sigmoid": {
3641 "op": Op.SIGMOID,
3642 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003644 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003645 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003646 TosaTensorValuesGen.tvgLazyGenDefault,
3647 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003648 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003649 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003650 "error_if_validators": (
3651 TosaErrorValidator.evWrongInputType,
3652 TosaErrorValidator.evWrongOutputType,
3653 TosaErrorValidator.evWrongInputList,
3654 TosaErrorValidator.evWrongOutputList,
3655 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003656 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 },
3658 "tanh": {
3659 "op": Op.TANH,
3660 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003662 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003664 TosaTensorValuesGen.tvgLazyGenDefault,
3665 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003668 "error_if_validators": (
3669 TosaErrorValidator.evWrongInputType,
3670 TosaErrorValidator.evWrongOutputType,
3671 TosaErrorValidator.evWrongInputList,
3672 TosaErrorValidator.evWrongOutputList,
3673 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003674 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003675 "compliance": {
3676 "abs_error_lower_bound": 0.5,
3677 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 },
Won Jeon78155c62023-06-10 00:20:04 +00003679 "erf": {
3680 "op": Op.ERF,
3681 "operands": (1, 0),
3682 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003683 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003684 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003685 TosaTensorValuesGen.tvgLazyGenDefault,
3686 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003687 ),
3688 "types": TYPE_FP,
3689 "error_if_validators": (
3690 TosaErrorValidator.evWrongInputType,
3691 TosaErrorValidator.evWrongOutputType,
3692 TosaErrorValidator.evWrongInputList,
3693 TosaErrorValidator.evWrongOutputList,
3694 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003695 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003696 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 # Elementwise Binary Operators
3699 "add": {
3700 "op": Op.ADD,
3701 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 "build_fcn": (
3703 build_binary_broadcast,
3704 TosaTensorGen.tgBroadcastFuzz,
3705 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003706 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 "error_if_validators": (
3710 TosaErrorValidator.evRankMismatch,
3711 TosaErrorValidator.evWrongInputType,
3712 TosaErrorValidator.evWrongOutputType,
3713 TosaErrorValidator.evWrongInputList,
3714 TosaErrorValidator.evWrongOutputList,
3715 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003716 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003718 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003719 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "arithmetic_right_shift": {
3722 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3723 "operands": (2, 0),
3724 "build_fcn": (
3725 build_arithmetic_right_shift,
3726 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 TosaArgGen.agArithmeticRightShift,
3729 ),
3730 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 "error_if_validators": (
3732 TosaErrorValidator.evRankMismatch,
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003738 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "bitwise_and": {
3742 "op": Op.BITWISE_AND,
3743 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003744 "build_fcn": (
3745 build_binary_broadcast,
3746 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003747 TosaTensorValuesGen.tvgLazyGenDefault,
3748 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 "error_if_validators": (
3752 TosaErrorValidator.evRankMismatch,
3753 TosaErrorValidator.evWrongInputType,
3754 TosaErrorValidator.evWrongOutputType,
3755 TosaErrorValidator.evWrongInputList,
3756 TosaErrorValidator.evWrongOutputList,
3757 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003758 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "bitwise_or": {
3762 "op": Op.BITWISE_OR,
3763 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 "build_fcn": (
3765 build_binary_broadcast,
3766 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003767 TosaTensorValuesGen.tvgLazyGenDefault,
3768 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 "error_if_validators": (
3772 TosaErrorValidator.evRankMismatch,
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003778 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "bitwise_xor": {
3782 "op": Op.BITWISE_XOR,
3783 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_binary_broadcast,
3786 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003787 TosaTensorValuesGen.tvgLazyGenDefault,
3788 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "error_if_validators": (
3792 TosaErrorValidator.evRankMismatch,
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003798 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003800 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003801 "intdiv": {
3802 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003803 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 "build_fcn": (
3805 build_binary_broadcast,
3806 TosaTensorGen.tgBroadcastFuzz,
3807 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003808 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003810 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 "error_if_validators": (
3812 TosaErrorValidator.evRankMismatch,
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003818 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 "logical_and": {
3822 "op": Op.LOGICAL_AND,
3823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 "build_fcn": (
3825 build_binary_broadcast,
3826 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003827 TosaTensorValuesGen.tvgLazyGenDefault,
3828 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003838 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "logical_left_shift": {
3842 "op": Op.LOGICAL_LEFT_SHIFT,
3843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_binary_broadcast,
3846 TosaTensorGen.tgBroadcastFuzz,
3847 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evRankMismatch,
3853 TosaErrorValidator.evWrongInputType,
3854 TosaErrorValidator.evWrongOutputType,
3855 TosaErrorValidator.evWrongInputList,
3856 TosaErrorValidator.evWrongOutputList,
3857 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003858 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "logical_right_shift": {
3862 "op": Op.LOGICAL_RIGHT_SHIFT,
3863 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003864 "build_fcn": (
3865 build_binary_broadcast,
3866 TosaTensorGen.tgBroadcastFuzz,
3867 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003868 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 "error_if_validators": (
3872 TosaErrorValidator.evRankMismatch,
3873 TosaErrorValidator.evWrongInputType,
3874 TosaErrorValidator.evWrongOutputType,
3875 TosaErrorValidator.evWrongInputList,
3876 TosaErrorValidator.evWrongOutputList,
3877 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003878 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "logical_or": {
3882 "op": Op.LOGICAL_OR,
3883 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 "build_fcn": (
3885 build_binary_broadcast,
3886 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003887 TosaTensorValuesGen.tvgLazyGenDefault,
3888 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 "error_if_validators": (
3892 TosaErrorValidator.evRankMismatch,
3893 TosaErrorValidator.evWrongInputType,
3894 TosaErrorValidator.evWrongOutputType,
3895 TosaErrorValidator.evWrongInputList,
3896 TosaErrorValidator.evWrongOutputList,
3897 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003898 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003901 "logical_xor": {
3902 "op": Op.LOGICAL_XOR,
3903 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003904 "build_fcn": (
3905 build_binary_broadcast,
3906 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003907 TosaTensorValuesGen.tvgLazyGenDefault,
3908 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003911 "error_if_validators": (
3912 TosaErrorValidator.evRankMismatch,
3913 TosaErrorValidator.evWrongInputType,
3914 TosaErrorValidator.evWrongOutputType,
3915 TosaErrorValidator.evWrongInputList,
3916 TosaErrorValidator.evWrongOutputList,
3917 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003918 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "maximum": {
3922 "op": Op.MAXIMUM,
3923 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_binary_broadcast,
3926 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
3928 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evRankMismatch,
3933 TosaErrorValidator.evWrongInputType,
3934 TosaErrorValidator.evWrongOutputType,
3935 TosaErrorValidator.evWrongInputList,
3936 TosaErrorValidator.evWrongOutputList,
3937 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003938 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003940 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003942 "minimum": {
3943 "op": Op.MINIMUM,
3944 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003945 "build_fcn": (
3946 build_binary_broadcast,
3947 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003948 TosaTensorValuesGen.tvgLazyGenDefault,
3949 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003951 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003952 "error_if_validators": (
3953 TosaErrorValidator.evRankMismatch,
3954 TosaErrorValidator.evWrongInputType,
3955 TosaErrorValidator.evWrongOutputType,
3956 TosaErrorValidator.evWrongInputList,
3957 TosaErrorValidator.evWrongOutputList,
3958 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003959 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003960 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003961 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 "mul": {
3964 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003965 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 "build_fcn": (
3967 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003968 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003969 TosaTensorValuesGen.tvgMul,
3970 TosaArgGen.agMul,
3971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003972 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003973 "error_if_validators": (
3974 TosaErrorValidator.evWrongInputType,
3975 TosaErrorValidator.evWrongOutputType,
3976 TosaErrorValidator.evWrongInputList,
3977 TosaErrorValidator.evWrongOutputList,
3978 TosaErrorValidator.evRankMismatch,
3979 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003980 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003982 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003983 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003985 "pow": {
3986 "op": Op.POW,
3987 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003988 "build_fcn": (
3989 build_binary_broadcast,
3990 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003991 TosaTensorValuesGen.tvgPow,
3992 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 "error_if_validators": (
3996 TosaErrorValidator.evRankMismatch,
3997 TosaErrorValidator.evWrongInputType,
3998 TosaErrorValidator.evWrongOutputType,
3999 TosaErrorValidator.evWrongInputList,
4000 TosaErrorValidator.evWrongOutputList,
4001 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004002 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004003 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004004 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 "sub": {
4007 "op": Op.SUB,
4008 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 "build_fcn": (
4010 build_binary_broadcast,
4011 TosaTensorGen.tgBroadcastFuzz,
4012 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004013 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004016 "error_if_validators": (
4017 TosaErrorValidator.evRankMismatch,
4018 TosaErrorValidator.evWrongInputType,
4019 TosaErrorValidator.evWrongOutputType,
4020 TosaErrorValidator.evWrongInputList,
4021 TosaErrorValidator.evWrongOutputList,
4022 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004023 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004024 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004025 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004026 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004027 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004028 "table": {
4029 "op": Op.TABLE,
4030 # Use the automatic generation functions to create the input array
4031 # but create the table tensor in the build function, as it may be
4032 # a different type from the input
4033 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 "build_fcn": (
4035 build_table,
4036 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004037 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004038 TosaArgGen.agTable,
4039 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004040 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 "error_if_validators": (
4042 TosaErrorValidator.evWrongInputType,
4043 TosaErrorValidator.evWrongOutputType,
4044 TosaErrorValidator.evWrongInputList,
4045 TosaErrorValidator.evWrongOutputList,
4046 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004048 # Elementwise Unary operators
4049 "abs": {
4050 "op": Op.ABS,
4051 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004052 "build_fcn": (
4053 build_unary,
4054 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004055 TosaTensorValuesGen.tvgLazyGenDefault,
4056 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004059 "error_if_validators": (
4060 TosaErrorValidator.evWrongInputType,
4061 TosaErrorValidator.evWrongOutputType,
4062 TosaErrorValidator.evWrongInputList,
4063 TosaErrorValidator.evWrongOutputList,
4064 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004065 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004067 "bitwise_not": {
4068 "op": Op.BITWISE_NOT,
4069 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004070 "build_fcn": (
4071 build_unary,
4072 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004073 TosaTensorValuesGen.tvgLazyGenDefault,
4074 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004075 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004076 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004077 "error_if_validators": (
4078 TosaErrorValidator.evWrongInputType,
4079 TosaErrorValidator.evWrongOutputType,
4080 TosaErrorValidator.evWrongInputList,
4081 TosaErrorValidator.evWrongOutputList,
4082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004084 "ceil": {
4085 "op": Op.CEIL,
4086 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 "build_fcn": (
4088 build_unary,
4089 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004090 TosaTensorValuesGen.tvgLazyGenDefault,
4091 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004092 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004094 "error_if_validators": (
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongInputList,
4098 TosaErrorValidator.evWrongOutputList,
4099 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004100 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004101 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "clz": {
4104 "op": Op.CLZ,
4105 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004106 "build_fcn": (
4107 build_unary,
4108 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004109 TosaTensorValuesGen.tvgLazyGenDefault,
4110 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004112 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004113 "error_if_validators": (
4114 TosaErrorValidator.evWrongInputType,
4115 TosaErrorValidator.evWrongOutputType,
4116 TosaErrorValidator.evWrongInputList,
4117 TosaErrorValidator.evWrongOutputList,
4118 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004119 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004120 "cos": {
4121 "op": Op.COS,
4122 "operands": (1, 0),
4123 "build_fcn": (
4124 build_unary,
4125 TosaTensorGen.tgBasic,
4126 TosaTensorValuesGen.tvgLazyGenDefault,
4127 TosaArgGen.agNone,
4128 ),
4129 "types": TYPE_FP,
4130 "error_if_validators": (
4131 TosaErrorValidator.evWrongInputType,
4132 TosaErrorValidator.evWrongOutputType,
4133 TosaErrorValidator.evWrongInputList,
4134 TosaErrorValidator.evWrongOutputList,
4135 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004136 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004137 "compliance": {"abs_error_normal_divisor": 2},
4138 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004139 "exp": {
4140 "op": Op.EXP,
4141 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004142 "build_fcn": (
4143 build_unary,
4144 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004145 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004146 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004148 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004149 "error_if_validators": (
4150 TosaErrorValidator.evWrongInputType,
4151 TosaErrorValidator.evWrongOutputType,
4152 TosaErrorValidator.evWrongInputList,
4153 TosaErrorValidator.evWrongOutputList,
4154 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004155 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004157 "floor": {
4158 "op": Op.FLOOR,
4159 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 "build_fcn": (
4161 build_unary,
4162 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004163 TosaTensorValuesGen.tvgLazyGenDefault,
4164 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004165 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004166 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004167 "error_if_validators": (
4168 TosaErrorValidator.evWrongInputType,
4169 TosaErrorValidator.evWrongOutputType,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004173 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004174 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004176 "log": {
4177 "op": Op.LOG,
4178 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004179 "build_fcn": (
4180 build_unary,
4181 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004182 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004183 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004184 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004185 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004186 "error_if_validators": (
4187 TosaErrorValidator.evWrongInputType,
4188 TosaErrorValidator.evWrongOutputType,
4189 TosaErrorValidator.evWrongInputList,
4190 TosaErrorValidator.evWrongOutputList,
4191 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004192 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004193 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004195 "logical_not": {
4196 "op": Op.LOGICAL_NOT,
4197 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004198 "build_fcn": (
4199 build_unary,
4200 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004201 TosaTensorValuesGen.tvgLazyGenDefault,
4202 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004204 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004205 "error_if_validators": (
4206 TosaErrorValidator.evWrongInputType,
4207 TosaErrorValidator.evWrongOutputType,
4208 TosaErrorValidator.evWrongInputList,
4209 TosaErrorValidator.evWrongOutputList,
4210 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004212 "negate": {
4213 "op": Op.NEGATE,
4214 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004215 "build_fcn": (
4216 build_unary,
4217 TosaTensorGen.tgBasic,
4218 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004219 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004221 "qgen": TosaQuantGen.qgUnary,
4222 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004223 "error_if_validators": (
4224 TosaErrorValidator.evInputZeroPointNotZero,
4225 TosaErrorValidator.evOutputZeroPointNotZero,
4226 TosaErrorValidator.evWrongInputType,
4227 TosaErrorValidator.evWrongOutputType,
4228 TosaErrorValidator.evWrongInputList,
4229 TosaErrorValidator.evWrongOutputList,
4230 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004231 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004232 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 "reciprocal": {
4234 "op": Op.RECIPROCAL,
4235 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004236 "build_fcn": (
4237 build_unary,
4238 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004239 TosaTensorValuesGen.tvgLazyGenDefault,
4240 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004242 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004243 "error_if_validators": (
4244 TosaErrorValidator.evWrongInputType,
4245 TosaErrorValidator.evWrongOutputType,
4246 TosaErrorValidator.evWrongInputList,
4247 TosaErrorValidator.evWrongOutputList,
4248 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004249 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004250 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004251 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004252 "rsqrt": {
4253 "op": Op.RSQRT,
4254 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004255 "build_fcn": (
4256 build_unary,
4257 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004258 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004259 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004260 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004261 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004262 "error_if_validators": (
4263 TosaErrorValidator.evWrongInputType,
4264 TosaErrorValidator.evWrongOutputType,
4265 TosaErrorValidator.evWrongInputList,
4266 TosaErrorValidator.evWrongOutputList,
4267 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004268 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004269 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004270 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004271 "sin": {
4272 "op": Op.SIN,
4273 "operands": (1, 0),
4274 "build_fcn": (
4275 build_unary,
4276 TosaTensorGen.tgBasic,
4277 TosaTensorValuesGen.tvgLazyGenDefault,
4278 TosaArgGen.agNone,
4279 ),
4280 "types": TYPE_FP,
4281 "error_if_validators": (
4282 TosaErrorValidator.evWrongInputType,
4283 TosaErrorValidator.evWrongOutputType,
4284 TosaErrorValidator.evWrongInputList,
4285 TosaErrorValidator.evWrongOutputList,
4286 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004287 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004288 "compliance": {"abs_error_normal_divisor": 2},
4289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004290 # Elementwise Ternary operators
4291 "select": {
4292 "op": Op.SELECT,
4293 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004294 "build_fcn": (
4295 build_select,
4296 TosaTensorGen.tgBroadcastFuzz,
4297 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004298 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004301 "error_if_validators": (
4302 TosaErrorValidator.evRankMismatch,
4303 TosaErrorValidator.evWrongInputType,
4304 TosaErrorValidator.evWrongOutputType,
4305 TosaErrorValidator.evWrongInputList,
4306 TosaErrorValidator.evWrongOutputList,
4307 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004308 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004309 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004310 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004311 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004312 # Comparison operators
4313 "equal": {
4314 "op": Op.EQUAL,
4315 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 "build_fcn": (
4317 build_comparison,
4318 TosaTensorGen.tgBroadcastFuzz,
4319 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004320 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004321 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004323 "error_if_validators": (
4324 TosaErrorValidator.evRankMismatch,
4325 TosaErrorValidator.evWrongInputType,
4326 TosaErrorValidator.evWrongOutputType,
4327 TosaErrorValidator.evWrongInputList,
4328 TosaErrorValidator.evWrongOutputList,
4329 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004330 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004331 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004332 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004333 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004334 "greater_equal": {
4335 "op": Op.GREATER_EQUAL,
4336 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004337 "build_fcn": (
4338 build_comparison,
4339 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004340 TosaTensorValuesGen.tvgLazyGenDefault,
4341 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004344 "error_if_validators": (
4345 TosaErrorValidator.evRankMismatch,
4346 TosaErrorValidator.evWrongInputType,
4347 TosaErrorValidator.evWrongOutputType,
4348 TosaErrorValidator.evWrongInputList,
4349 TosaErrorValidator.evWrongOutputList,
4350 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004351 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004352 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004353 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004355 "greater": {
4356 "op": Op.GREATER,
4357 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004358 "build_fcn": (
4359 build_comparison,
4360 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004361 TosaTensorValuesGen.tvgLazyGenDefault,
4362 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004364 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004365 "error_if_validators": (
4366 TosaErrorValidator.evRankMismatch,
4367 TosaErrorValidator.evWrongInputType,
4368 TosaErrorValidator.evWrongOutputType,
4369 TosaErrorValidator.evWrongInputList,
4370 TosaErrorValidator.evWrongOutputList,
4371 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004372 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004373 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004374 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004376 # Reduction operators
4377 "reduce_all": {
4378 "op": Op.REDUCE_ALL,
4379 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004380 "build_fcn": (
4381 build_reduce,
4382 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004383 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004384 TosaArgGen.agAxis,
4385 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004386 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004387 "error_if_validators": (
4388 TosaErrorValidator.evAxisLargerRank,
4389 TosaErrorValidator.evAxisSmallerZero,
4390 TosaErrorValidator.evShapeOfAxisNotOne,
4391 TosaErrorValidator.evWrongInputType,
4392 TosaErrorValidator.evWrongOutputType,
4393 TosaErrorValidator.evWrongRank,
4394 TosaErrorValidator.evWrongInputList,
4395 TosaErrorValidator.evWrongOutputList,
4396 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004398 "reduce_any": {
4399 "op": Op.REDUCE_ANY,
4400 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004401 "build_fcn": (
4402 build_reduce,
4403 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004404 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004405 TosaArgGen.agAxis,
4406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004407 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004408 "error_if_validators": (
4409 TosaErrorValidator.evAxisLargerRank,
4410 TosaErrorValidator.evAxisSmallerZero,
4411 TosaErrorValidator.evShapeOfAxisNotOne,
4412 TosaErrorValidator.evWrongInputType,
4413 TosaErrorValidator.evWrongOutputType,
4414 TosaErrorValidator.evWrongRank,
4415 TosaErrorValidator.evWrongInputList,
4416 TosaErrorValidator.evWrongOutputList,
4417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004419 "reduce_max": {
4420 "op": Op.REDUCE_MAX,
4421 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004422 "build_fcn": (
4423 build_reduce,
4424 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004425 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004426 TosaArgGen.agAxis,
4427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004428 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004429 "error_if_validators": (
4430 TosaErrorValidator.evAxisLargerRank,
4431 TosaErrorValidator.evAxisSmallerZero,
4432 TosaErrorValidator.evShapeOfAxisNotOne,
4433 TosaErrorValidator.evWrongInputType,
4434 TosaErrorValidator.evWrongOutputType,
4435 TosaErrorValidator.evWrongRank,
4436 TosaErrorValidator.evWrongInputList,
4437 TosaErrorValidator.evWrongOutputList,
4438 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004439 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004440 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004441 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004442 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004444 "build_fcn": (
4445 build_reduce,
4446 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004447 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004448 TosaArgGen.agAxis,
4449 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004450 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004451 "error_if_validators": (
4452 TosaErrorValidator.evAxisLargerRank,
4453 TosaErrorValidator.evAxisSmallerZero,
4454 TosaErrorValidator.evShapeOfAxisNotOne,
4455 TosaErrorValidator.evWrongInputType,
4456 TosaErrorValidator.evWrongOutputType,
4457 TosaErrorValidator.evWrongRank,
4458 TosaErrorValidator.evWrongInputList,
4459 TosaErrorValidator.evWrongOutputList,
4460 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004461 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004463 "reduce_product": {
4464 "op": Op.REDUCE_PRODUCT,
4465 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004466 "build_fcn": (
4467 build_reduce,
4468 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004469 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004470 TosaArgGen.agAxis,
4471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004472 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004473 "error_if_validators": (
4474 TosaErrorValidator.evAxisLargerRank,
4475 TosaErrorValidator.evAxisSmallerZero,
4476 TosaErrorValidator.evShapeOfAxisNotOne,
4477 TosaErrorValidator.evWrongInputType,
4478 TosaErrorValidator.evWrongOutputType,
4479 TosaErrorValidator.evWrongRank,
4480 TosaErrorValidator.evWrongInputList,
4481 TosaErrorValidator.evWrongOutputList,
4482 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004483 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004485 "reduce_sum": {
4486 "op": Op.REDUCE_SUM,
4487 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004488 "build_fcn": (
4489 build_reduce,
4490 TosaTensorGen.tgBasic,
4491 TosaTensorValuesGen.tvgReduceSum,
4492 TosaArgGen.agAxis,
4493 ),
James Ward24dbc422022-10-19 12:20:31 +01004494 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004495 "error_if_validators": (
4496 TosaErrorValidator.evAxisLargerRank,
4497 TosaErrorValidator.evAxisSmallerZero,
4498 TosaErrorValidator.evShapeOfAxisNotOne,
4499 TosaErrorValidator.evWrongInputType,
4500 TosaErrorValidator.evWrongOutputType,
4501 TosaErrorValidator.evWrongRank,
4502 TosaErrorValidator.evWrongInputList,
4503 TosaErrorValidator.evWrongOutputList,
4504 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004505 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004506 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004507 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004508 "concat": {
4509 "op": Op.CONCAT,
4510 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004511 "build_fcn": (
4512 build_concat,
4513 TosaTensorGen.tgConcat,
4514 TosaTensorValuesGen.tvgConcat,
4515 TosaArgGen.agAxis,
4516 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004517 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 "error_if_validators": (
4519 TosaErrorValidator.evAxisLargerRank,
4520 TosaErrorValidator.evAxisSmallerZero,
4521 TosaErrorValidator.evConcatInputRankMismatch,
4522 TosaErrorValidator.evConcatShapeSumMismatch,
4523 TosaErrorValidator.evConcatInputDimMismatch,
4524 TosaErrorValidator.evWrongInputType,
4525 TosaErrorValidator.evWrongOutputType,
4526 TosaErrorValidator.evWrongOutputList,
4527 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004528 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004529 },
4530 "pad": {
4531 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004532 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004533 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004534 "build_fcn": (
4535 build_pad,
4536 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004537 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004538 TosaArgGen.agPad,
4539 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004540 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 "error_if_validators": (
4542 TosaErrorValidator.evWrongInputType,
4543 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004544 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004545 TosaErrorValidator.evWrongOutputType,
4546 TosaErrorValidator.evWrongInputList,
4547 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004548 TosaErrorValidator.evRankMismatch,
4549 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004550 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004551 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004552 },
Won Jeona21b2e82023-08-10 10:33:01 +00004553 "dim": {
4554 "op": Op.DIM,
4555 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004556 "rank": (1, gtu.MAX_TENSOR_RANK),
Won Jeona21b2e82023-08-10 10:33:01 +00004557 "build_fcn": (
4558 build_dim,
4559 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004560 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004561 TosaArgGen.agAxis,
4562 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004563 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004564 "error_if_validators": (
4565 TosaErrorValidator.evAxisLargerRank,
4566 TosaErrorValidator.evAxisSmallerZero,
4567 TosaErrorValidator.evWrongInputType,
4568 TosaErrorValidator.evWrongInputList,
4569 TosaErrorValidator.evWrongOutputList,
4570 TosaErrorValidator.evWrongRank,
4571 ),
4572 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004573 "reshape": {
4574 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004575 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004576 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004577 "build_fcn": (
4578 build_reshape,
4579 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004580 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004581 TosaArgGen.agReshape,
4582 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004583 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004584 "error_if_validators": (
4585 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4586 TosaErrorValidator.evWrongInputType,
4587 TosaErrorValidator.evWrongOutputType,
4588 TosaErrorValidator.evWrongInputList,
4589 TosaErrorValidator.evWrongOutputList,
4590 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004591 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004592 },
4593 "reverse": {
4594 "op": Op.REVERSE,
4595 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004596 "build_fcn": (
4597 build_reverse,
4598 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004599 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004600 TosaArgGen.agAxis,
4601 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004602 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004603 "error_if_validators": (
4604 TosaErrorValidator.evAxisSmallerZero,
4605 TosaErrorValidator.evAxisLargerRank,
4606 TosaErrorValidator.evWrongInputType,
4607 TosaErrorValidator.evWrongOutputType,
4608 TosaErrorValidator.evWrongInputList,
4609 TosaErrorValidator.evWrongOutputList,
4610 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004611 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004612 },
4613 "slice": {
4614 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004615 "operands": (3, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004616 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004617 "build_fcn": (
4618 build_slice,
4619 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004620 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004621 TosaArgGen.agSlice,
4622 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004623 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004624 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004625 # TODO Turn off these error categories for now as the reference
4626 # model cannot allocate memory space for empty tensor. We probably
4627 # can report an accurate error messege at the right place during
4628 # exeuction.
4629 # TosaErrorValidator.evStartSmallerZero,
4630 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004631 TosaErrorValidator.evStartSizeOutsideBounds,
4632 TosaErrorValidator.evSizeOutputShapeMismatch,
4633 TosaErrorValidator.evInputSizeStartLengthMismatch,
4634 TosaErrorValidator.evWrongRank,
4635 TosaErrorValidator.evWrongInputType,
4636 TosaErrorValidator.evWrongOutputType,
4637 TosaErrorValidator.evWrongInputList,
4638 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004639 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004640 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004641 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004642 },
4643 "tile": {
4644 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004645 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004646 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004647 "build_fcn": (
4648 build_tile,
4649 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004650 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004651 TosaArgGen.agTile,
4652 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004653 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004654 "error_if_validators": (
4655 TosaErrorValidator.evWrongInputType,
4656 TosaErrorValidator.evWrongOutputType,
4657 TosaErrorValidator.evWrongInputList,
4658 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004659 TosaErrorValidator.evRankMismatch,
4660 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004661 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004662 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004663 },
4664 "transpose": {
4665 "op": Op.TRANSPOSE,
4666 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004667 "rank": (1, gtu.MAX_TENSOR_RANK),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004668 "build_fcn": (
4669 build_transpose,
4670 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004671 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004672 TosaArgGen.agTranspose,
4673 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004674 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004675 "error_if_validators": (
4676 TosaErrorValidator.evIndexOutsideBounds,
4677 TosaErrorValidator.evIndexUsedTwice,
4678 TosaErrorValidator.evWrongInputType,
4679 TosaErrorValidator.evWrongOutputType,
4680 TosaErrorValidator.evWrongInputList,
4681 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004682 TosaErrorValidator.evWrongRank,
4683 TosaErrorValidator.evRankMismatch,
4684 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004685 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004686 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004687 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004688 # Data nodes
4689 "const": {
4690 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004691 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004692 "build_fcn": (
4693 build_const,
4694 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004695 TosaTensorValuesGen.tvgLazyGenDefault,
4696 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004697 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004698 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha01ad8e1e22024-03-19 12:42:17 +00004699 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004701 "identity": {
4702 "op": Op.IDENTITY,
4703 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004704 "build_fcn": (
4705 build_unary,
4706 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004707 TosaTensorValuesGen.tvgLazyGenDefault,
4708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004709 ),
evacha011adff832024-03-06 17:33:44 +00004710 "types": TYPE_FIB + [DType.INT4, DType.INT48],
evacha01ad8e1e22024-03-19 12:42:17 +00004711 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004712 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004713 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004714 "gather": {
4715 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004716 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004717 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004718 "build_fcn": (
4719 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004720 TosaTensorGen.tgGather,
4721 TosaTensorValuesGen.tvgGather,
4722 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004723 ),
James Ward24dbc422022-10-19 12:20:31 +01004724 "types": (
4725 DType.INT8,
4726 DType.INT16,
4727 DType.INT32,
4728 DType.FP16,
4729 DType.BF16,
4730 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004731 DType.FP8E4M3,
4732 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004733 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004734 "error_if_validators": (
4735 TosaErrorValidator.evWrongInputType,
4736 TosaErrorValidator.evWrongOutputType,
4737 TosaErrorValidator.evWrongInputList,
4738 TosaErrorValidator.evWrongOutputList,
4739 TosaErrorValidator.evWrongRank,
4740 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004741 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004742 },
4743 "scatter": {
4744 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004745 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004746 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004747 "build_fcn": (
4748 build_scatter,
4749 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004750 TosaTensorValuesGen.tvgScatter,
4751 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004752 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004753 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004754 "error_if_validators": (
4755 TosaErrorValidator.evWrongInputType,
4756 TosaErrorValidator.evWrongOutputType,
4757 TosaErrorValidator.evWrongInputList,
4758 TosaErrorValidator.evWrongOutputList,
4759 TosaErrorValidator.evWrongRank,
4760 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004761 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004762 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004763 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004764 "resize": {
4765 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004766 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004767 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004768 "build_fcn": (
4769 build_resize,
4770 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004771 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004772 TosaArgGen.agResize,
4773 ),
James Ward24dbc422022-10-19 12:20:31 +01004774 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004775 "invalid_test_validators": (
4776 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004777 ),
4778 "error_if_validators": (
4779 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004780 TosaErrorValidator.evScaleSmallerEqualZero,
4781 TosaErrorValidator.evScaleNLargerMax,
4782 TosaErrorValidator.evScaleDLargerMax,
4783 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004784 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004785 TosaErrorValidator.evBorderSmallerMin,
4786 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004787 TosaErrorValidator.evWrongInputType,
4788 TosaErrorValidator.evWrongOutputType,
4789 TosaErrorValidator.evWrongRank,
4790 TosaErrorValidator.evWrongInputList,
4791 TosaErrorValidator.evWrongOutputList,
4792 TosaErrorValidator.evBatchMismatch,
4793 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004794 TosaErrorValidator.evResizeOutputShapeMismatch,
4795 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004796 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004797 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004798 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004799 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004800 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004801 "cast": {
4802 "op": Op.CAST,
4803 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004804 "build_fcn": (
4805 build_cast,
4806 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004807 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004808 TosaArgGen.agCast,
4809 ),
James Ward8b390432022-08-12 20:48:56 +01004810 "types": (
4811 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004812 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004813 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004814 DType.INT8,
4815 DType.INT16,
4816 DType.INT32,
4817 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004818 DType.FP8E4M3,
4819 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004820 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004821 "error_if_validators": (
4822 TosaErrorValidator.evWrongInputType,
4823 TosaErrorValidator.evWrongOutputType,
4824 TosaErrorValidator.evWrongInputList,
4825 TosaErrorValidator.evWrongOutputList,
4826 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004827 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson708da822023-11-15 16:25:45 +00004828 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004829 },
4830 "rescale": {
4831 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004832 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004833 "build_fcn": (
4834 build_rescale,
4835 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004836 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004837 TosaArgGen.agRescale,
4838 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004839 "types": [
4840 DType.UINT8,
4841 DType.INT8,
4842 DType.INT16,
4843 DType.INT32,
4844 DType.INT48,
4845 DType.UINT16,
4846 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004847 "error_if_validators": (
4848 TosaErrorValidator.evInputZeroPointNotZero,
4849 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004850 TosaErrorValidator.evU16InputZeroPointNotValid,
4851 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 TosaErrorValidator.evScaleTrue,
4853 TosaErrorValidator.evScaleNotTrue,
4854 TosaErrorValidator.evWrongInputType,
4855 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004856 TosaErrorValidator.evWrongInputList,
4857 TosaErrorValidator.evWrongOutputList,
4858 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004859 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004860 # Custom
4861 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004862 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004863 # Two varients of cond_if, one that generates one of two constant tensors (no
4864 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4865 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004866 "cond_if_const": {
4867 "op": Op.COND_IF,
4868 "operands": (0, 2),
4869 "build_fcn": (
4870 build_cond_if_const,
4871 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004872 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004873 TosaArgGen.agCondIf,
4874 ),
4875 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004876 "error_if_validators": (
4877 TosaErrorValidator.evOutputListThenGraphMismatch,
4878 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004879 TosaErrorValidator.evCondIfCondNotMatchingBool,
4880 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004881 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004882 },
4883 "cond_if_binary": {
4884 "op": Op.COND_IF,
4885 "operands": (2, 0),
4886 "build_fcn": (
4887 build_cond_if_binary,
4888 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004889 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004890 TosaArgGen.agCondIf,
4891 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004892 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004893 "error_if_validators": (
4894 TosaErrorValidator.evInputListThenGraphMismatch,
4895 TosaErrorValidator.evInputListElseGraphMismatch,
4896 TosaErrorValidator.evOutputListThenGraphMismatch,
4897 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004898 TosaErrorValidator.evCondIfCondNotMatchingBool,
4899 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004900 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004901 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004902 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004903 "while_loop": {
4904 "op": Op.WHILE_LOOP,
4905 "operands": (0, 1),
4906 "build_fcn": (
4907 build_while_loop,
4908 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004909 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004910 TosaArgGen.agWhileLoop,
4911 ),
4912 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004913 "error_if_validators": (
4914 TosaErrorValidator.evInputListOutputListMismatch,
4915 TosaErrorValidator.evInputListCondGraphMismatch,
4916 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4917 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4918 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004919 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004920 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004921 },
Luke Hutton57287132023-02-06 14:54:18 +00004922 "fft2d": {
4923 "op": Op.FFT2D,
4924 "operands": (2, 0),
4925 "rank": (3, 3),
4926 "build_fcn": (
4927 build_fft2d,
4928 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004929 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004930 TosaArgGen.agFFT2d,
4931 ),
4932 "types": [DType.FP32],
4933 "error_if_validators": (
4934 TosaErrorValidator.evWrongInputType,
4935 TosaErrorValidator.evWrongOutputType,
4936 TosaErrorValidator.evWrongInputList,
4937 TosaErrorValidator.evWrongOutputList,
4938 TosaErrorValidator.evWrongRank,
4939 TosaErrorValidator.evBatchMismatch,
4940 TosaErrorValidator.evKernelNotPowerOfTwo,
4941 TosaErrorValidator.evFFTInputShapeMismatch,
4942 TosaErrorValidator.evFFTOutputShapeMismatch,
4943 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004944 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton57287132023-02-06 14:54:18 +00004945 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004946 "rfft2d": {
4947 "op": Op.RFFT2D,
4948 "operands": (1, 0),
4949 "rank": (3, 3),
4950 "build_fcn": (
4951 build_rfft2d,
4952 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004953 TosaTensorValuesGen.tvgLazyGenDefault,
4954 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004955 ),
4956 "types": [DType.FP32],
4957 "error_if_validators": (
4958 TosaErrorValidator.evWrongInputType,
4959 TosaErrorValidator.evWrongOutputType,
4960 TosaErrorValidator.evWrongInputList,
4961 TosaErrorValidator.evWrongOutputList,
4962 TosaErrorValidator.evWrongRank,
4963 TosaErrorValidator.evBatchMismatch,
4964 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004965 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004966 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004967 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton261b7b62023-01-10 14:50:31 +00004968 },
Won Jeon74342e52024-01-09 00:34:40 +00004969 # Shape
4970 "add_shape": {
4971 "op": Op.ADD_SHAPE,
4972 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004973 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004974 "build_fcn": (
4975 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004976 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004977 TosaTensorValuesGen.tvgAddSub,
4978 TosaArgGen.agNone,
4979 ),
4980 "types": [DType.SHAPE],
4981 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4982 },
4983 "sub_shape": {
4984 "op": Op.SUB_SHAPE,
4985 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004986 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004987 "build_fcn": (
4988 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004989 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004990 TosaTensorValuesGen.tvgAddSub,
4991 TosaArgGen.agNone,
4992 ),
4993 "types": [DType.SHAPE],
4994 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4995 },
4996 "mul_shape": {
4997 "op": Op.MUL_SHAPE,
4998 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004999 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005000 "build_fcn": (
5001 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005002 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005003 TosaTensorValuesGen.tvgMul,
5004 TosaArgGen.agNone,
5005 ),
5006 "types": [DType.SHAPE],
5007 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5008 },
5009 "div_shape": {
5010 "op": Op.DIV_SHAPE,
5011 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005012 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005013 "build_fcn": (
5014 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005015 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005016 TosaTensorValuesGen.tvgIntDiv,
5017 TosaArgGen.agNone,
5018 ),
5019 "types": [DType.SHAPE],
5020 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5021 },
5022 "concat_shape": {
5023 "op": Op.CONCAT_SHAPE,
5024 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005025 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005026 "build_fcn": (
5027 build_concat,
5028 TosaTensorGen.tgConcat,
5029 TosaTensorValuesGen.tvgConcat,
5030 TosaArgGen.agNone,
5031 ),
5032 "types": [DType.SHAPE],
5033 "error_if_validators": (),
5034 },
5035 "const_shape": {
5036 "op": Op.CONST_SHAPE,
5037 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005038 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005039 "build_fcn": (
5040 build_const,
5041 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005042 TosaTensorValuesGen.tvgLazyGenDefault,
5043 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005044 ),
5045 "types": [DType.SHAPE],
5046 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005047 }
5048
Kevin Cheng550ccc52021-03-03 11:21:43 -08005049
Eric Kunzee5e26762020-10-13 16:11:07 -07005050class OutputShaper:
5051 # Methods in this class compute the expected output shape and datatype
5052 # for common classes of operations
5053 def __init__(self):
5054 pass
5055
5056 # These methods return arguments that can be used for
5057 # creating a new output tensor
5058 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005059 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
5060 if error_name != ErrorIf.RankMismatch:
5061 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005062 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005063
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005064 # Work out broadcasted output shape (when not ERRORIF test)
Eric Kunzee5e26762020-10-13 16:11:07 -07005065 shape = []
5066 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005067 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07005068 shape.append(b.shape[i])
5069 else:
5070 shape.append(a.shape[i])
5071
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005072 if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
5073 # Can only create this error for rank > 0
5074 fuzz_idx = rng.integers(0, len(shape))
Jerry Ge135c9552023-05-23 20:59:32 +00005075 shape[fuzz_idx] += 1
5076
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005077 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005078 all_dtypes = [
5079 DType.INT8,
5080 DType.INT16,
5081 DType.INT32,
5082 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01005083 DType.FP16,
5084 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005085 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005086 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005087 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5088 outputDType = rng.choice(wrong_dtypes)
5089 else:
5090 outputDType = a.dtype
5091
5092 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005093
5094 @staticmethod
5095 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005096 assert len(a.shape) == len(b.shape)
5097 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005098
5099 shape = []
5100 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005101 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005102 shape.append(a.shape[i])
5103
Kevin Cheng550ccc52021-03-03 11:21:43 -08005104 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005105
5106 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005107 def unaryOp(ser, rng, a, error_name=None):
5108 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005109 all_dtypes = [
5110 DType.INT8,
5111 DType.INT16,
5112 DType.INT32,
5113 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005114 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005115 DType.FP16,
5116 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005117 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005118 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5119 outputDType = rng.choice(wrong_dtypes)
5120 else:
5121 outputDType = a.dtype
5122
5123 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005124
5125 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005126 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005127 if error_name != ErrorIf.RankMismatch:
5128 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005129 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005130
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005131 # Work out broadcasted output shape (when not ERRORIF test)
Eric Kunzee5e26762020-10-13 16:11:07 -07005132 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005133 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005134 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005135 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5136 else:
5137 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07005138
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005139 if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
5140 # Can only create this error for rank > 0
5141 fuzz_idx = rng.integers(0, len(shape))
Jerry Ge135c9552023-05-23 20:59:32 +00005142 shape[fuzz_idx] += 1
5143
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005144 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005145 all_dtypes = [
5146 DType.INT8,
5147 DType.INT16,
5148 DType.INT32,
5149 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005150 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005151 DType.FP16,
5152 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005153 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005154 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5155 outputDType = rng.choice(wrong_dtypes)
5156 else:
5157 outputDType = a.dtype
5158
5159 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005160
5161 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005162 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005163 if error_name != ErrorIf.RankMismatch:
5164 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005165 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005166
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005167 # Work out broadcasted output shape
Eric Kunzee5e26762020-10-13 16:11:07 -07005168 shape = []
5169 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08005170 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07005171 shape.append(b.shape[i])
5172 else:
5173 shape.append(a.shape[i])
5174
Jeremy Johnson18a379d2024-03-28 15:53:21 +00005175 if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
5176 # Can only create this error for rank > 0
5177 fuzz_idx = rng.integers(0, len(shape))
Jerry Ge135c9552023-05-23 20:59:32 +00005178 shape[fuzz_idx] += 1
5179
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005180 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005181 wrong_dtypes = [
5182 DType.INT8,
5183 DType.INT16,
5184 DType.INT32,
5185 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005186 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005187 DType.FP16,
5188 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005189 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005190 outputDType = rng.choice(wrong_dtypes)
5191 else:
5192 outputDType = DType.BOOL
5193
5194 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005195
5196 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005197 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005198 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005199 if error_name not in [
5200 ErrorIf.AxisSmallerZero,
5201 ErrorIf.AxisLargerRank,
5202 ErrorIf.ShapeOfAxisNotOne,
5203 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005204 shape[axis] = 1
5205 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5206 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005207
Matthew Haddond6ce7252021-09-29 15:35:44 +01005208 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005209 all_dtypes = [
5210 DType.INT8,
5211 DType.INT16,
5212 DType.INT32,
5213 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005214 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005215 DType.FP16,
5216 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005217 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005218 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5219 outputDType = rng.choice(wrong_dtypes)
5220 else:
5221 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005222
Matthew Haddond6ce7252021-09-29 15:35:44 +01005223 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005224
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the ARGMAX output tensor.

    The valid output drops dimension ``axis`` from the input shape and has
    dtype INT32. ERROR_IF cases perturb the rank, the shape or the dtype.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        a: Input tensor (only ``shape`` is read).
        axis: Dimension reduced by the operator.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    # An out-of-range axis cannot be deleted, so keep the input shape as is.
    axis_out_of_range = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
    )
    if not axis_out_of_range:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Randomly lose the leading dimension (when more than one remains)
        # or gain a trailing unit dimension.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005260
@staticmethod
def _get_conv_output_type(input_dtype):
    """Return the convolution output dtype implied by ``input_dtype``.

    FP16/BF16/FP32 map to themselves, FP8E4M3/FP8E5M2 map to FP16,
    INT8/INT4 map to INT32 and INT16 maps to INT48.

    Args:
        input_dtype: Input (IFM) ``DType``.

    Returns:
        The matching output ``DType``.

    Raises:
        AssertionError: If ``input_dtype`` is not a supported convolution
            input type.
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    if input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    if input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    if input_dtype == DType.INT16:
        return DType.INT48
    # Raise explicitly rather than using an ``assert`` statement: an
    # unsupported type must never silently yield None, and asserts are
    # stripped under ``python -O``.
    raise AssertionError(f"Unsupported convolution data type {input_dtype}")
5272
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV2D output tensor.

    Layouts: IFM is NHWC, filter is OHWI and the OFM is NHWC.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        ifm: Input feature map tensor (``shape`` and ``dtype`` are read).
        filter: Weight tensor (``shape`` is read).
        accum_dtype: Accumulator dtype; not read here.
        strides: Height and width strides.
        padding: ``padding[0:2]`` pad the height, ``padding[2:4]`` the width.
        dilations: Height and width dilations.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    in_h, in_w = ifm.shape[1], ifm.shape[2]
    kernel_h, kernel_w = filter.shape[1], filter.shape[2]

    # Output extent per spatial axis from input size, padding, dilated
    # kernel extent and stride.
    out_h = (
        in_h - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        in_w - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow height, 2 -> grow width, 3 -> grow both.
        axes_to_grow = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error case is not hit.
        if axes_to_grow in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if axes_to_grow in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid; pick a potentially correct output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excluded = [DType.FP16]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005326
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV3D output tensor.

    Layouts: IFM is NDHWC, filter is ODHWI and the OFM is NDHWC.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        ifm: Input feature map tensor (``shape`` and ``dtype`` are read).
        filter: Weight tensor (``shape`` is read).
        accum_dtype: Accumulator dtype; not read here.
        strides: Depth, height and width strides.
        padding: Two padding values per spatial axis, in D, H, W order.
        dilations: Depth, height and width dilations.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Output extent for D, H and W: input dims 1..3, filter dims 1..3 and
    # one pair of padding values per axis.
    spatial = [
        (
            ifm.shape[i + 1]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i + 1] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # 1/2/3 grow only D/H/W respectively, 4 grows all three.
        selector = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error case is not hit.
        for i in range(3):
            if selector in (i + 1, 4):
                spatial[i] += rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid; pick a potentially correct output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
5388
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the DEPTHWISE_CONV2D output tensor.

    Layouts: IFM is NHWC, filter is HWCM and the OFM is NHW(C*M).

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        ifm: Input feature map tensor (``shape`` and ``dtype`` are read).
        filter: Weight tensor (``shape`` is read).
        accum_dtype: Accumulator dtype; not read here.
        strides: Height and width strides.
        padding: ``padding[0:2]`` pad the height, ``padding[2:4]`` the width.
        dilations: Height and width dilations.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Dilated kernel extent minus one, per spatial axis (filter dims 0 and 1).
    span_h = (filter.shape[0] - 1) * dilations[0]
    span_w = (filter.shape[1] - 1) * dilations[1]
    out_h = (ifm.shape[1] - 1 + padding[0] + padding[1] - span_h) // strides[0] + 1
    out_w = (ifm.shape[2] - 1 + padding[2] + padding[3] - span_w) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        pick = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error case is not hit.
        if pick in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if pick in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    # Output channels are C * M from the HWCM filter.
    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid; pick a potentially correct output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005439
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling op (NHWC in and out).

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        ifm: Input tensor (``shape`` and ``dtype`` are read).
        kernel: Kernel height and width.
        stride: Height and width strides.
        pad: ``pad[0:2]`` pad the height, ``pad[2:4]`` the width.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``; the dtype matches ``ifm`` unless a
        wrong output type is requested.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride or padding: the test is invalid anyway, so use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        grow = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error case is not hit.
        if grow in (1, 3):
            out_h += rng.choice(choices) * stride[0]
        if grow in (2, 3):
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - set([ifm.dtype])))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005479
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the FULLY_CONNECTED output tensor.

    Shapes: input is [N, IC], filter is [OC, IC] and the output is [N, OC].

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Not read here.
        input: Input tensor (``shape`` is read).
        filter: Weight tensor (``shape`` is read).
        accum_dtype: Used directly as the output dtype.
        error_name: Not read here.

    Returns:
        The result of ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and deliberately invalidated for
    # ERROR_IF tests), so it is used as the output dtype unchanged.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005492
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the MATMUL output tensor.

    Shapes: a is [N, H, C], b is [N, C, W] and the output is [N, H, W].

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for WrongOutputType.
        a: First input tensor (``shape`` and ``dtype`` are read).
        b: Second input tensor (``shape`` is read).
        accum_dtype: Output dtype for the non-error path.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # Incorrect output types for the input dtype are assembled from a
        # shared head, a dtype-specific middle and a shared FP8 tail.
        head = (DType.INT4, DType.INT8, DType.INT16)
        fp8_tail = (DType.FP8E4M3, DType.FP8E5M2)
        if a.dtype == DType.INT8:
            middle = (DType.INT48, DType.FP32, DType.FP16, DType.BF16)
        elif a.dtype == DType.INT16:
            middle = (DType.INT32, DType.FP32, DType.FP16, DType.BF16)
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            middle = (DType.INT32, DType.INT48, DType.FP32, DType.BF16)
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            middle = (DType.INT32, DType.INT48)
        out_dtype = rng.choice(a=head + middle + fp8_tail)
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005558
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the CONCAT output tensor.

    The output starts from the first input's shape; when the ranks and axis
    are valid, the remaining inputs' sizes along ``axis`` are added to it.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        axis: Concatenation axis.
        inputs: Input tensors; ``inputs[0]`` supplies the base shape and dtype.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    # Summing along the axis is impossible with mismatched ranks or an
    # invalid axis, so the first input's shape is kept unchanged then.
    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - set([first.dtype])))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005594
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the PAD output tensor.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        a: Input tensor (``shape`` and ``dtype`` are read).
        padding: Per-dimension pair of padding amounts.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Enlarge one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005625
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the DIM output tensor: shape ``[1]`` with dtype SHAPE.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for WrongOutputType.
        a: Input tensor; not read here.
        axis: Queried axis; not read here.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    # None of these is SHAPE, so every candidate is a wrong output type.
    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5646
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the RESHAPE output tensor.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        a: Input tensor (``dtype`` is read).
        shape: Requested output shape; copied, never modified.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension so the element count no longer matches.
        new_shape = [dim + rng.integers(1, 10) for dim in new_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        new_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        new_dtype = a.dtype

    return ser.addOutput(new_shape, new_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005671
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the SLICE output tensor with shape ``size``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        input: Input tensor (``shape`` and ``dtype`` are read).
        start: Slice start; not read here.
        size: Slice size, used (copied) as the output shape.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([input.dtype])))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        # Dims <= 2 are only grown; larger dims may grow or shrink by up to 2.
        out_shape = [
            dim + rng.choice([1, 2] if dim <= 2 else [-2, -1, 1, 2])
            for dim in out_shape
        ]
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005705
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the TILE output tensor.

    Each output dimension is ``a.shape[i] * multiples[i]``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        a: Input tensor (``shape`` and ``dtype`` are read).
        multiples: Per-dimension repeat counts; must match the input rank.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(multiples) == len(a.shape)
    tiled = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        tiled = gtu.get_rank_mismatch_shape(rng, tiled)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(tiled, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(tiled, rng.choice(list(set(candidates) - set([a.dtype]))))
Eric Kunzee5e26762020-10-13 16:11:07 -07005734
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the TRANSPOSE output tensor.

    Output dimension ``i`` is ``a.shape[perms[i]]``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        a: Input tensor (``shape`` and ``dtype`` are read).
        perms: Permutation; must have the same length as the input rank.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    permuted = a.shape.copy()
    assert len(perms) == len(permuted)

    # An invalid permutation cannot be applied; keep the input shape then.
    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        permuted = [a.shape[perms[i]] for i in range(len(permuted))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        permuted = [dim + rng.integers(1, 10) for dim in permuted]
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(permuted, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005767
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the GATHER output tensor.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``
    and the dtype matches ``values`` unless a wrong output type is requested.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for WrongOutputType.
        values: Values tensor (``shape`` and ``dtype`` are read).
        indices: Indices tensor (``shape`` is read).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        # values must be rank 3 and indices rank 2, sharing the first dim.
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - set([values.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005793
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the SCATTER output tensor.

    The output has the shape of ``values_in`` and, unless a wrong output
    type is requested, its dtype.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for WrongOutputType.
        values_in: Values tensor being scattered into (``shape``/``dtype``).
        indices: Indices tensor (``shape`` is read by the checks).
        input: Tensor of values to scatter (``shape`` is read by the checks).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy the shape so the output tensor does not share (and cannot later
    # mutate) the input tensor's shape list, as the other shapers do.
    output_shape = list(values_in.shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005824
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the TABLE output tensor.

    The output keeps the input shape. Its dtype is INT32 for INT16 input and
    INT8 otherwise, unless a wrong output type is requested.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator, only consulted for WrongOutputType.
        input: Input tensor (``shape`` and ``dtype`` are read).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([dt for dt in candidates if dt != output_dtype])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005844
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the RESIZE output tensor (NHWC).

    Per spatial axis the output size is
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``, where ``scale``
    is ``[y_n, y_d, x_n, x_d]`` and ``offset``/``border`` are ``[y, x]``.

    Args:
        serializer: Serializer whose ``addOutput(shape, dtype)`` creates the
            tensor.
        rng: Random generator, only consulted for ERROR_IF cases.
        input: Input tensor (``shape`` is read).
        mode: Resize mode; not read here.
        scale: ``[y_n, y_d, x_n, x_d]`` scale numerators/denominators.
        offset: ``[y, x]`` offsets.
        border: ``[y, x]`` borders.
        input_dtype: Not read here.
        output_dtype: Dtype of the created tensor.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The result of ``serializer.addOutput``.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp every scale term to at least 1 for the shape arithmetic.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Keep the output tensor valid, since scale, offset or border may
        # have been changed for ERROR_IFs.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    def _step(size, delta):
        # Move by one whole denominator; step down instead of up when growing
        # would reach MAX_RESIZE_DIMENSION.
        if size + delta >= gtu.MAX_RESIZE_DIMENSION:
            size -= delta
            assert size > 0  # Should have been caught in agResize
        else:
            size += delta
        return size

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> change height, 2 -> change width, 3 -> change both. Steps are
        # multiples of scale_y/x_d so the non-integer error case is not hit.
        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            oh = _step(oh, y_den)
        if change in (2, 3):
            ow = _step(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[3] is not read on this path; the batch size is repeated
        # as the last dimension instead.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005923
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor that keeps the input's shape but uses ``out_dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: not used by this shaper.
        val: input tensor; only its ``.shape`` is read (passed through as-is).
        out_dtype: dtype for the new output tensor.
        error_name: not used; no ERROR_IF case alters this output.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = val.shape
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005927
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    The shape is the caller-supplied ``output_shape``. The dtype comes from
    ``OutputShaper._get_conv_output_type(ifm.dtype)``. Either may then be
    perturbed to provoke an ERROR_IF case.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``choice``.
        ifm: input tensor; only its ``.dtype`` is read.
        output_shape: indexable output shape. NOTE(review): it is modified
            IN PLACE for ``ErrorIf.ConvOutputShapeMismatch``, so the caller's
            object sees the change too. Confirm callers rely on this.
        accum_dtype: not used by this shaper.
        error_name: optional ``ErrorIf`` case to simulate:
            - ``ConvOutputShapeMismatch``: index 1, index 2, or both of
              ``output_shape`` (presumably height/width) grow by a random
              pick from ``[1, 2, 3]``.
            - ``WrongInputType``: dtype is fixed to ``DType.INT32``.
            - ``WrongOutputType``: dtype is drawn at random from
              ``gtu.usableDTypes``. It excludes ``[DType.FP16, DType.FP32]``
              for an FP16 input, and the computed dtype otherwise.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        bump_sizes = [1, 2, 3]
        # 1 -> bump index 1 only, 2 -> index 2 only, 3 -> bump both
        which_dims = rng.choice(bump_sizes)
        if which_dims in (1, 3):
            output_shape[1] += rng.choice(bump_sizes)
        if which_dims in (2, 3):
            output_shape[2] += rng.choice(bump_sizes)

    if error_name != ErrorIf.WrongInputType:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
    else:
        # The input dtype is deliberately invalid; INT32 is just a
        # potentially correct output dtype to pair with it
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005953
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D.

    Both outputs share one shape (a copy of ``ifm1.shape``) and the input
    dtype. Either may be perturbed to provoke an ERROR_IF case.

    Args:
        serializer: serializer whose ``addOutput(shape, dtype)`` creates
            each tensor.
        rng: random generator providing ``choice`` and ``integers``.
        ifm1, ifm2: input tensors.
            - Their dtypes must match.
            - Their shapes must match unless ``error_name`` is
              ``ErrorIf.FFTInputShapeMismatch``.
            - Rank 3 is asserted unless ``error_name`` is ``ErrorIf.WrongRank``.
        error_name: optional ``ErrorIf`` case to simulate:
            - ``WrongOutputType``: dtype is drawn at random from
              ``gtu.usableDTypes``, excluding ``DType.FP32``.
            - ``BatchMismatch``: dim 0 grows by ``rng.integers(1, 10)``.
            - ``FFTOutputShapeMismatch``: dim 1 or 2 (chosen at random)
              grows by ``rng.integers(1, 10)``.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``
        (presumably the real and imaginary parts; confirm with the caller).
        Both are created from the same shape list object.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5984
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D.

    The output shape is the input shape with the last dimension ``W``
    replaced by ``W // 2 + 1``. The dtype is the input dtype. Either may be
    perturbed to provoke an ERROR_IF case.

    Args:
        serializer: serializer whose ``addOutput(shape, dtype)`` creates
            each tensor.
        rng: random generator providing ``choice`` and ``integers``.
        value: input tensor. Rank 3 is asserted unless ``error_name`` is
            ``ErrorIf.WrongRank``.
        error_name: optional ``ErrorIf`` case to simulate:
            - ``WrongOutputType``: dtype is drawn at random from
              ``gtu.usableDTypes``, excluding ``DType.FP32``.
            - ``BatchMismatch``: dim 0 grows by ``rng.integers(1, 10)``.
            - ``FFTOutputShapeMismatch``: dim 1 or 2 (chosen at random)
              grows by ``rng.integers(1, 10)``.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
        Both are created from the same shape list object.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Keep leading dims; the last axis becomes W // 2 + 1
    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bumped_dim = rng.choice([1, 2])
        out_shape[bumped_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00006009
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an ADD_SHAPE-style operation.

    The shape is a copy of ``a.shape`` and the dtype is ``DType.SHAPE``.
    Either may be perturbed to provoke an ERROR_IF case.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator providing ``integers`` and ``choice``.
        a, b: input tensors.
            - Their dtypes must match.
            - Their ranks must match unless ``error_name`` is
              ``ErrorIf.RankMismatch``.
            - Rank 0 is not expected (asserted).
        error_name: optional ``ErrorIf`` case to simulate:
            - ``DimensionMismatch``: one randomly chosen dim grows by 1.
            - ``WrongOutputType``: dtype is drawn at random from
              ``gtu.get_wrong_output_type(a.dtype)``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = a.shape.copy()
    # Rank 0 tests are not expected, so there is always a dim to fuzz
    assert len(out_shape) > 0

    if error_name == ErrorIf.DimensionMismatch:
        out_shape[rng.integers(0, len(out_shape))] += 1

    out_dtype = (
        rng.choice(gtu.get_wrong_output_type(a.dtype))
        if error_name == ErrorIf.WrongOutputType
        else DType.SHAPE
    )
    return ser.addOutput(out_shape, out_dtype)