blob: 68a4e9412c574e8e686143c86e798d2fe474fcfb [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data file
TOSA_AUTOGENERATED_HEADER = "\n".join(
    [
        f"// Copyright (c) {datetime.today().year}, ARM Limited",
        "// SPDX-License-Identifier: Apache-2.0",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests",
        "",  # trailing newline after the last header line
    ]
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
79 for dtype in (DType.FP32, DType.FP16, DType.BF16):
80 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000198 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 return np.int64(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
201 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
202
203 if dtype == DType.FP16:
204 return np.float16(f_tensor)
205 else:
206 f32_tensor = np.float32(f_tensor)
207 if dtype == DType.BF16:
208 # Floor the last 16 bits of each f32 value
209 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
210 else:
211 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100213 # All other integer types
214 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
Kevin Cheng989cb052021-04-28 16:29:44 -0700216 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 placeholders = []
218
Kevin Cheng989cb052021-04-28 16:29:44 -0700219 assert len(shape_list) == len(dtype_list)
220
Jeremy Johnson1271c442023-09-05 11:39:26 +0100221 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700222 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 if not self.args.lazy_data_gen:
224 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700225 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700226
227 return placeholders
228
Kevin Cheng989cb052021-04-28 16:29:44 -0700229 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 consts = []
231
Kevin Cheng989cb052021-04-28 16:29:44 -0700232 assert len(shape_list) == len(dtype_list)
233
Jeremy Johnson1271c442023-09-05 11:39:26 +0100234 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700235 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100236 if not self.args.lazy_data_gen:
237 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700238 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700239
240 return consts
241
242 def makeShape(self, rank):
243 if self.targetted_shape:
244 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800245 return np.int32(
246 self.rng.integers(
247 low=self.args.tensor_shape_range[0],
248 high=self.args.tensor_shape_range[1],
249 size=rank,
250 )
251 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
253 def setTargetShape(self, shape):
254 self.targetted_shape = shape
255
256 def randInt(self, low=0, high=256):
257 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
258
259 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 low, high = self.getDTypeRange(dtype)
261
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100264 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100265 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100266 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100267 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
268 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 elif dtype == DType.BOOL:
270 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000271 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 # Special size
273 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
275 return np.int32(self.rng.integers(low, high, size=1))[0]
276
277 def shapeStr(self, shape):
278
279 sStr = []
280 # Convert to strings
281 for i in shape:
282 sStr.append(str(i))
283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeStr(self, dtype):
287 if isinstance(dtype, list) or isinstance(dtype, tuple):
288 assert len(dtype) >= 2
289 strs = [self.typeStr(t) for t in dtype]
290 # Limit types to the first 2 as the 3rd is the accumulator
291 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100293 if dtype in gtu.DTYPE_ATTRIBUTES:
294 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700295 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100296 raise Exception(
297 "Unknown dtype, cannot convert to string: {}".format(dtype)
298 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100300 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100301 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100302 if dtype in gtu.DTYPE_ATTRIBUTES:
303 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100305 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
Luke Hutton57287132023-02-06 14:54:18 +0000307 def constrictBatchSize(self, shape):
308 # Limit the batch size unless an explicit target shape set
309 if self.args.max_batch_size and not self.args.target_shapes:
310 shape[0] = min(shape[0], self.args.max_batch_size)
311 return shape
312
James Ward30124a82023-02-02 14:56:33 +0000313 def makeDimension(self):
314 return self.randInt(
315 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
316 )
317
    def tensorComplianceMetaData(
        self, op, inputType, argsDict, outputTensor, errorName
    ):
        """Create the compliance meta data dict for an op's output tensor.

        Args:
            op: operator test-list entry. ``op["op"]`` is the TOSA Op; an
                optional ``op["compliance"]`` dict may hold ``"ulp"`` and/or
                ``"abs_error_lower_bound"`` settings.
            inputType: input data type (the callers in this file pass the
                first input tensor's dtype).
            argsDict: test arguments. ``"dg_type"`` is always read. Depending
                on the mode chosen, ``"s"``, ``"ks"``/``"ksb"`` or ``"n"`` are
                also read.
            outputTensor: the result tensor; only its ``dtype`` is used.
            errorName: ERROR_IF test name, or a falsy value for valid tests.

        Returns:
            A dict with ``"mode"`` (a ComplianceMode name), ``"data_type"``
            and, for some modes, a mode-specific ``*_info`` entry. Returns
            None for ERROR_IF tests and for data types compliance cannot
            handle.
        """
        # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
        # (the check below skips these ops for any input type that
        # gtu.dtypeIsSupportedByCompliance rejects)
        UNSUPPORTED_NON_FP32_INPUT_OPS = (
            Op.MATMUL,
            Op.CONV2D,
            Op.FULLY_CONNECTED,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        )
        if (
            errorName
            or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
            or (
                not gtu.dtypeIsSupportedByCompliance(inputType)
                and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
            )
        ):
            # No compliance for error tests or unsupported types currently
            return None

        # Create compliance meta data for expected output tensor
        # ("mode" is a placeholder filled in below; creating it first keeps it
        # as the first key in the serialized JSON)
        compliance_tens = {
            "mode": None,
            # Data type is needed for all FP runs, as refmodel precise mode produces FP64
            "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
        }
        # Mode selection is first-match: data generator type first, then
        # per-op compliance settings, then specific ops, else EXACT
        if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
            mode = gtu.ComplianceMode.DOT_PRODUCT
            # "ksb" takes precedence over "ks" when present
            # NOTE(review): presumably "ksb" is a bias-adjusted kernel size - confirm
            compliance_tens["dot_product_info"] = {
                "s": argsDict["s"],
                "ks": int(argsDict["ksb"])
                if "ksb" in argsDict
                else int(argsDict["ks"]),
            }
        elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
            mode = gtu.ComplianceMode.FP_SPECIAL
        elif "compliance" in op and "ulp" in op["compliance"]:
            mode = gtu.ComplianceMode.ULP
            compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        elif op["op"] == Op.REDUCE_PRODUCT:
            mode = gtu.ComplianceMode.REDUCE_PRODUCT
            compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
        elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
            mode = gtu.ComplianceMode.ABS_ERROR
            # The lower bound is optional and only added when the op sets it
            if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
                compliance_tens["abs_error_info"] = {
                    "lower_bound": op["compliance"]["abs_error_lower_bound"]
                }
        else:
            mode = gtu.ComplianceMode.EXACT
        compliance_tens["mode"] = gtu.ComplianceMode(mode).name

        return compliance_tens
373
374 # Build Op functions
375 # Create the output tensor (calling OutputShaper as needed)
376 # Do final tweaks to attributes (if necessary for errorIf)
377 # Add Op into graph
378 # Return resulting tensor information or BuildInfo
379
380 class BuildInfo:
381 """Enhanced build information containing result tensor and associated compliance dict."""
382
383 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000384 if isinstance(resultTensor, list):
385 assert complianceDict is None or isinstance(complianceDict, list)
386 self.resultTensorList = resultTensor
387 self.complianceDictList = complianceDict
388 else:
389 self.resultTensorList = [resultTensor]
390 if complianceDict is None:
391 self.complianceDictList = None
392 else:
393 self.complianceDictList = [complianceDict]
394
395 def getComplianceInfo(self):
396 if self.complianceDictList is None:
397 return None
398 else:
399 tens_dict = {}
400 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
401 if comp is not None:
402 tens_dict[tens.name] = comp
403
404 if tens_dict:
405 # Have some compliance data, so return the info
406 compliance = {
407 "version": "0.1",
408 "tensors": tens_dict,
409 }
410 else:
411 compliance = None
412 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700413
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000414 def build_unary(
415 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
416 ):
417 assert len(inputs) == 1
418 a = inputs[0]
419 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100420
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000421 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100422
423 # Ensure new output type has correct qinfo
424 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000425 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000426 qinfo = [
427 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000428 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000429 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100430
431 # Invalidate Input/Output list for error if checks.
432 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000433 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100434 pCount, cCount = op["operands"]
435 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
437 self, error_name, input_list, output_list
438 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100439
Les Bell729b0352021-11-24 10:28:21 +0000440 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100441 self.ser,
442 validator_fcns,
443 error_name,
444 op=op,
445 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000446 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000447 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000448 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100449 input_list=input_list,
450 output_list=output_list,
451 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000452 ):
453 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100454
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000455 attr = None
456 if op["op"] == Op.NEGATE:
457 attr = ts.TosaSerializerAttribute()
458 attr.NegateAttribute(qinfo[0], qinfo[1])
459
460 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000461
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000462 compliance = self.tensorComplianceMetaData(
463 op, a.dtype, args_dict, result_tensor, error_name
464 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000465 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700466
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000467 def build_binary_broadcast(
468 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
469 ):
470 assert len(inputs) == 2
471 a, b = inputs
472 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 self.ser, self.rng, a, b, error_name
474 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100475
476 # Invalidate Input/Output list for error if checks.
477 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000478 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100479 pCount, cCount = op["operands"]
480 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000481 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
482 self, error_name, input_list, output_list
483 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100484
Les Bell729b0352021-11-24 10:28:21 +0000485 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100486 self.ser,
487 validator_fcns,
488 error_name,
489 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000490 input1=a,
491 input2=b,
492 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000493 output_dtype=result_tensor.dtype,
494 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100495 input_list=input_list,
496 output_list=output_list,
497 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000498 ):
499 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100500
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000501 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000502
Jeremy Johnson9a758382023-11-07 16:27:35 +0000503 compliance = self.tensorComplianceMetaData(
504 op, a.dtype, args_dict, result_tensor, error_name
505 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000506
507 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700508
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100509 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700510 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000511 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700512 return result_tens
513
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000514 def build_arithmetic_right_shift(
515 self, op, a, b, round, validator_fcns=None, error_name=None
516 ):
517 result_tens = OutputShaper.binaryBroadcastOp(
518 self.ser, self.rng, a, b, error_name
519 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100520
521 # Invalidate Input/Output list for error if checks.
522 input_list = [a.name, b.name]
523 output_list = [result_tens.name]
524 pCount, cCount = op["operands"]
525 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000526 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
527 self, error_name, input_list, output_list
528 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100529
Les Bell729b0352021-11-24 10:28:21 +0000530 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100531 self.ser,
532 validator_fcns,
533 error_name,
534 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000535 input1=a,
536 input2=b,
537 input_dtype=a.dtype,
538 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000539 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100540 input_list=input_list,
541 output_list=output_list,
542 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000543 ):
544 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800545
546 attr = ts.TosaSerializerAttribute()
547 attr.ArithmeticRightShiftAttribute(round)
548
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000549 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800550 return result_tens
551
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100552 def build_mul(
553 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
554 ):
555 assert len(inputs) == 2
556 a, b = inputs
557 shift = args_dict["shift"]
558
559 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 self.ser, self.rng, a, b, error_name
561 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700562
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100563 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100564 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100565 result_tensor.setDtype(DType.INT32)
566
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100567 if error_name == ErrorIf.WrongOutputType:
568 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
569 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100570 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100571
572 # Invalidate Input/Output list for error if checks.
573 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100574 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100575 pCount, cCount = op["operands"]
576 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000577 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
578 self, error_name, input_list, output_list
579 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100580
Les Bell729b0352021-11-24 10:28:21 +0000581 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100582 self.ser,
583 validator_fcns,
584 error_name,
585 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000586 input1=a,
587 input2=b,
588 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100589 output_dtype=result_tensor.dtype,
590 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100591 input_list=input_list,
592 output_list=output_list,
593 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000594 ):
595 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700596
Kevin Chengaee1fac2020-11-11 13:54:06 -0800597 attr = ts.TosaSerializerAttribute()
598 attr.MulAttribute(shift)
599
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000600 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100601
602 compliance = self.tensorComplianceMetaData(
603 op, a.dtype, args_dict, result_tensor, error_name
604 )
605
606 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700607
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100608 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
609 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700610
Kevin Chengfe392ce2021-10-18 21:51:55 +0000611 attr = ts.TosaSerializerAttribute()
612 attr.TableAttribute(table)
613
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100614 # Invalidate Input/Output list for error if checks.
615 input_list = [a.name]
616 output_list = [result_tens.name]
617 pCount, cCount = op["operands"]
618 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000619 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
620 self, error_name, input_list, output_list
621 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100622
Les Bell729b0352021-11-24 10:28:21 +0000623 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100624 self.ser,
625 validator_fcns,
626 error_name,
627 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000628 input_shape=a.shape,
629 input_dtype=a.dtype,
630 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000631 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100632 input_list=input_list,
633 output_list=output_list,
634 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000635 ):
636 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100637
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000638 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700639
640 return result_tens
641
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000642 def build_select(
643 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
644 ):
645 assert len(inputs) == 3
646 cond, a, b = inputs
647
648 result_tensor = OutputShaper.selectOp(
649 self.ser, self.rng, cond, a, b, error_name
650 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100651
652 # Invalidate Input/Output list for error if checks.
653 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000654 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100655 pCount, cCount = op["operands"]
656 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000657 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
658 self, error_name, input_list, output_list
659 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100660
Les Bell729b0352021-11-24 10:28:21 +0000661 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100662 self.ser,
663 validator_fcns,
664 error_name,
665 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000666 input1=cond,
667 input2=a,
668 input3=b,
669 input_shape=a.shape,
670 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000671 output_dtype=result_tensor.dtype,
672 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100673 input_list=input_list,
674 output_list=output_list,
675 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000676 ):
677 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100678
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000679 self.ser.addOperator(
680 op["op"],
681 input_list,
682 output_list,
683 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000684 compliance = self.tensorComplianceMetaData(
685 op, a.dtype, args_dict, result_tensor, error_name
686 )
687
688 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700689
Jeremy Johnsona0150012023-11-15 15:52:06 +0000690 def build_comparison(
691 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
692 ):
693 assert len(inputs) == 2
694 a, b = inputs
695
696 result_tensor = OutputShaper.binaryComparisonOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000697 self.ser, self.rng, a, b, error_name
698 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100699
700 # Invalidate Input/Output list for error if checks.
701 input_list = [a.name, b.name]
Jeremy Johnsona0150012023-11-15 15:52:06 +0000702 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100703 pCount, cCount = op["operands"]
704 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000705 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
706 self, error_name, input_list, output_list
707 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100708
Les Bell729b0352021-11-24 10:28:21 +0000709 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100710 self.ser,
711 validator_fcns,
712 error_name,
713 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000714 input1=a,
715 input2=b,
716 input_shape=a.shape,
717 input_dtype=a.dtype,
Jeremy Johnsona0150012023-11-15 15:52:06 +0000718 output_shape=result_tensor.shape,
719 output_dtype=result_tensor.dtype,
720 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100721 input_list=input_list,
722 output_list=output_list,
723 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000724 ):
725 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100726
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000727 self.ser.addOperator(
728 op["op"],
729 input_list,
730 output_list,
731 )
Jeremy Johnsona0150012023-11-15 15:52:06 +0000732
733 compliance = self.tensorComplianceMetaData(
734 op, a.dtype, args_dict, result_tensor, error_name
735 )
736 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700737
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000738 def build_argmax(
739 self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
740 ):
741 assert len(inputs) == 1
742 a = inputs[0]
743 axis = args_dict["axis"]
744 result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100745
746 # Invalidate Input/Output list for error if checks.
747 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000748 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100749 pCount, cCount = op["operands"]
750 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000751 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
752 self, error_name, input_list, output_list
753 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100754
Les Bell729b0352021-11-24 10:28:21 +0000755 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100756 self.ser,
757 validator_fcns,
758 error_name,
759 op=op,
760 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000761 input_shape=a.shape,
762 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000763 output_shape=result_tensor.shape,
764 output_dtype=result_tensor.dtype,
765 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100766 input_list=input_list,
767 output_list=output_list,
768 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000769 ):
770 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700771
772 attr = ts.TosaSerializerAttribute()
773 attr.AxisAttribute(axis)
774
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000775 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000776
777 compliance = self.tensorComplianceMetaData(
778 op, inputs[0].dtype, args_dict, result_tensor, error_name
779 )
780 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000782 def build_pool2d(
783 self,
784 op,
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100785 inputs,
786 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000787 validator_fcns=None,
788 error_name=None,
789 qinfo=None,
790 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100791 assert len(inputs) == 1
792 input = inputs[0]
793 # max_pool has no accum_dtype
794 accum_dtype = (
795 args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
796 )
797 stride = args_dict["stride"]
798 pad = args_dict["pad"]
799 kernel = args_dict["kernel"]
800
Jeremy Johnson0601f802023-11-08 16:28:09 +0000801 result_tensor = OutputShaper.pool2dOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000802 self.ser, self.rng, input, kernel, stride, pad, error_name
803 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100804
805 # Ensure new output type has correct qinfo
806 if error_name == ErrorIf.WrongInputType:
807 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000808 qinfo = [
809 TosaQuantGen.getZeroPoint(self, input.dtype),
Jeremy Johnson0601f802023-11-08 16:28:09 +0000810 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000811 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100812
813 # Invalidate Input/Output list for error if checks.
814 input_list = [input.name]
Jeremy Johnson0601f802023-11-08 16:28:09 +0000815 output_list = [result_tensor.name]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100816 pCount, cCount = op["operands"]
817 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000818 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
819 self, error_name, input_list, output_list
820 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100821
Les Bell729b0352021-11-24 10:28:21 +0000822 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100823 self.ser,
824 validator_fcns,
825 error_name,
826 op=op,
827 input_shape=input.shape,
828 input_dtype=input.dtype,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000829 output_shape=result_tensor.shape,
830 output_dtype=result_tensor.dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100831 kernel=kernel,
832 stride=stride,
833 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000834 qinfo=qinfo,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000835 result_tensors=[result_tensor],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100836 input_list=input_list,
837 output_list=output_list,
838 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000839 ):
840 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700841
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000842 if qinfo is None:
843 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700844
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000845 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100846 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000847
848 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700849
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100850 compliance = self.tensorComplianceMetaData(
851 op, inputs[0].dtype, args_dict, result_tensor, error_name
852 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100853
854 return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100855
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000856 def build_conv2d(
857 self,
858 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100859 inputs,
860 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000861 validator_fcns=None,
862 error_name=None,
863 qinfo=None,
864 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100865 assert len(inputs) == 3
866 ifm, filter, bias = inputs
867 accum_dtype = args_dict["acc_type"]
868 strides = args_dict["stride"]
869 padding = args_dict["pad"]
870 dilations = args_dict["dilation"]
871
Kevin Cheng550ccc52021-03-03 11:21:43 -0800872 assert len(padding) == 4
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100873 result_tensor = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100874 self.ser,
875 self.rng,
876 ifm,
877 filter,
878 accum_dtype,
879 strides,
880 padding,
881 dilations,
882 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000883 )
884
885 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000886 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
887 DType.INT8,
888 DType.UINT8,
889 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000890 qinfo = [
891 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100892 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000893 ]
Les Bell0e027d42021-11-09 14:42:14 +0000894
895 # Invalidate Input/Output list for error_if checks.
896 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100897 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +0000898 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000899 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
900 self, error_name, input_list, output_list
901 )
Les Bell0e027d42021-11-09 14:42:14 +0000902
Les Bell729b0352021-11-24 10:28:21 +0000903 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000904 self.ser,
905 validator_fcns,
906 error_name,
907 op=op,
908 input_dtype=ifm.dtype,
909 weight_dtype=filter.dtype,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100910 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +0000911 qinfo=qinfo,
912 input_list=input_list,
913 num_operands=num_operands,
914 output_list=output_list,
915 pad=padding,
916 stride=strides,
917 dilation=dilations,
918 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100919 weight_shape=filter.shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100920 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +0000921 ):
922 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700923
Tai Lyd3797f02023-11-15 23:06:19 +0000924 # TODO - Test local_bound, for now set local bound attribute to False
925 local_bound = False
926
Eric Kunzee5e26762020-10-13 16:11:07 -0700927 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +0000928 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -0700929
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000930 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100931
932 compliance = self.tensorComplianceMetaData(
933 op, ifm.dtype, args_dict, result_tensor, error_name
934 )
935
936 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700937
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000938 def build_conv3d(
939 self,
940 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100941 inputs,
942 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000943 validator_fcns=None,
944 error_name=None,
945 qinfo=None,
946 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100947 assert len(inputs) == 3
948 ifm, filter, bias = inputs
949 accum_dtype = args_dict["acc_type"]
950 strides = args_dict["stride"]
951 padding = args_dict["pad"]
952 dilations = args_dict["dilation"]
953
Kevin Cheng1533b852021-09-01 12:51:58 -0700954 assert len(padding) == 6
955 result_tens = OutputShaper.conv3dOp(
James Ward8b390432022-08-12 20:48:56 +0100956 self.ser,
957 self.rng,
958 ifm,
959 filter,
960 accum_dtype,
961 strides,
962 padding,
963 dilations,
964 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000965 )
966
967 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000968 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
969 DType.INT8,
970 DType.UINT8,
971 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000972 qinfo = [
973 TosaQuantGen.getZeroPoint(self, ifm.dtype),
974 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
975 ]
Les Bell0e027d42021-11-09 14:42:14 +0000976
977 # Invalidate Input/Output list for error_if checks.
978 input_list = [ifm.name, filter.name, bias.name]
979 output_list = [result_tens.name]
980 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000981 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
982 self, error_name, input_list, output_list
983 )
Les Bell0e027d42021-11-09 14:42:14 +0000984
Les Bell729b0352021-11-24 10:28:21 +0000985 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000986 self.ser,
987 validator_fcns,
988 error_name,
989 op=op,
990 input_dtype=ifm.dtype,
991 weight_dtype=filter.dtype,
992 output_dtype=result_tens.dtype,
993 qinfo=qinfo,
994 input_list=input_list,
995 num_operands=num_operands,
996 output_list=output_list,
997 pad=padding,
998 stride=strides,
999 dilation=dilations,
1000 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001001 weight_shape=filter.shape,
1002 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +00001003 ):
1004 return None
Kevin Cheng1533b852021-09-01 12:51:58 -07001005
Tai Lyd3797f02023-11-15 23:06:19 +00001006 # TODO - Test local_bound, for now set local bound attribute to False
1007 local_bound = False
1008
Kevin Cheng1533b852021-09-01 12:51:58 -07001009 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001010 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Kevin Cheng1533b852021-09-01 12:51:58 -07001011
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001012 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -07001013 return result_tens
1014
Kevin Cheng550ccc52021-03-03 11:21:43 -08001015 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001016 self,
1017 op,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001018 inputs,
1019 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001020 validator_fcns=None,
1021 error_name=None,
1022 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001023 ):
Jeremy Johnson95a67102024-01-10 14:16:39 +00001024 assert len(inputs) == 3
1025 ifm, filter, bias = inputs
1026 accum_dtype = args_dict["acc_type"]
1027 strides = args_dict["stride"]
1028 out_pad = args_dict["pad"]
1029 output_shape = args_dict["out_shape"]
1030
TatWai Chong24594f52022-06-08 00:48:04 -07001031 assert len(out_pad) == 4
Jeremy Johnson95a67102024-01-10 14:16:39 +00001032 result_tensor = OutputShaper.transposeConv2DOp(
James Ward8b390432022-08-12 20:48:56 +01001033 self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001034 )
Les Bell0e027d42021-11-09 14:42:14 +00001035
1036 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001037 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
1038 DType.INT8,
1039 DType.UINT8,
1040 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001041 qinfo = [
1042 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnson95a67102024-01-10 14:16:39 +00001043 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001044 ]
Les Bell0e027d42021-11-09 14:42:14 +00001045
1046 # Invalidate Input/Output list for error_if checks.
1047 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnson95a67102024-01-10 14:16:39 +00001048 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +00001049 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001050 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1051 self, error_name, input_list, output_list
1052 )
Les Bell0e027d42021-11-09 14:42:14 +00001053
Les Bell729b0352021-11-24 10:28:21 +00001054 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00001055 self.ser,
1056 validator_fcns,
1057 error_name,
1058 op=op,
1059 input_dtype=ifm.dtype,
1060 weight_dtype=filter.dtype,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001061 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +00001062 qinfo=qinfo,
1063 input_list=input_list,
1064 num_operands=num_operands,
1065 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -07001066 pad=out_pad,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001067 stride=strides,
Les Bell0e027d42021-11-09 14:42:14 +00001068 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001069 weight_shape=filter.shape,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001070 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +00001071 ):
1072 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001073
Tai Lyd3797f02023-11-15 23:06:19 +00001074 # TODO - Test local_bound, for now set local bound attribute to False
1075 local_bound = False
1076
Eric Kunzee5e26762020-10-13 16:11:07 -07001077 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001078 attr.TransposeConvAttribute(
Jeremy Johnson95a67102024-01-10 14:16:39 +00001079 out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
Tai Lyd3797f02023-11-15 23:06:19 +00001080 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001081
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001082 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson95a67102024-01-10 14:16:39 +00001083
1084 compliance = self.tensorComplianceMetaData(
1085 op, ifm.dtype, args_dict, result_tensor, error_name
1086 )
1087
1088 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001089
Kevin Cheng550ccc52021-03-03 11:21:43 -08001090 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001091 self,
1092 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001093 inputs,
1094 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001095 validator_fcns=None,
1096 error_name=None,
1097 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001098 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001099 assert len(inputs) == 3
1100 ifm, filter, bias = inputs
1101 accum_dtype = args_dict["acc_type"]
1102 strides = args_dict["stride"]
1103 padding = args_dict["pad"]
1104 dilations = args_dict["dilation"]
1105
Jeremy Johnson4f931302024-01-04 17:05:24 +00001106 result_tensor = OutputShaper.depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01001107 self.ser,
1108 self.rng,
1109 ifm,
1110 filter,
1111 accum_dtype,
1112 strides,
1113 padding,
1114 dilations,
1115 error_name,
Les Bell0e027d42021-11-09 14:42:14 +00001116 )
1117
1118 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001119 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
1120 DType.INT8,
1121 DType.UINT8,
1122 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001123 qinfo = [
1124 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnson4f931302024-01-04 17:05:24 +00001125 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001126 ]
Les Bell0e027d42021-11-09 14:42:14 +00001127
1128 # Invalidate Input/Output list for error_if checks.
1129 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnson4f931302024-01-04 17:05:24 +00001130 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +00001131 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001132 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1133 self, error_name, input_list, output_list
1134 )
Les Bell0e027d42021-11-09 14:42:14 +00001135
Les Bell729b0352021-11-24 10:28:21 +00001136 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00001137 self.ser,
1138 validator_fcns,
1139 error_name,
1140 op=op,
1141 input_dtype=ifm.dtype,
1142 weight_dtype=filter.dtype,
Jeremy Johnson4f931302024-01-04 17:05:24 +00001143 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +00001144 qinfo=qinfo,
1145 input_list=input_list,
1146 num_operands=num_operands,
1147 output_list=output_list,
1148 pad=padding,
1149 stride=strides,
1150 dilation=dilations,
1151 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001152 weight_shape=filter.shape,
Jeremy Johnson4f931302024-01-04 17:05:24 +00001153 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +00001154 ):
1155 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001156
Tai Lyd3797f02023-11-15 23:06:19 +00001157 # TODO - Test local_bound, for now set local bound attribute to False
1158 local_bound = False
1159
Eric Kunzee5e26762020-10-13 16:11:07 -07001160 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001161 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -07001162
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001163 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson4f931302024-01-04 17:05:24 +00001164
1165 compliance = self.tensorComplianceMetaData(
1166 op, ifm.dtype, args_dict, result_tensor, error_name
1167 )
1168
1169 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001170
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001171 def build_fully_connected(
James Ward8b390432022-08-12 20:48:56 +01001172 self,
1173 op,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001174 inputs,
1175 args_dict,
James Ward8b390432022-08-12 20:48:56 +01001176 validator_fcns=None,
1177 error_name=None,
1178 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001179 ):
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001180 assert len(inputs) == 3
1181 ifm, filter, bias = inputs
1182 accum_dtype = args_dict["acc_type"]
1183
1184 result_tensor = OutputShaper.fullyConnectedOp(
James Ward8b390432022-08-12 20:48:56 +01001185 self.ser, self.rng, ifm, filter, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001186 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001187
1188 # Invalidate Input/Output list for error if checks.
1189 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001190 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001191 pCount, cCount = op["operands"]
1192 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001193 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1194 self, error_name, input_list, output_list
1195 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001196
Les Bell729b0352021-11-24 10:28:21 +00001197 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001198 self.ser,
1199 validator_fcns,
1200 error_name,
1201 op=op,
1202 input_shape=ifm.shape,
1203 input_dtype=ifm.dtype,
1204 weight_dtype=filter.dtype,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001205 output_shape=result_tensor.shape,
1206 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001207 qinfo=qinfo,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001208 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001209 input_list=input_list,
1210 output_list=output_list,
1211 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001212 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001213 ):
1214 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001215
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001216 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001217 attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001218
1219 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001220
1221 compliance = self.tensorComplianceMetaData(
1222 op, ifm.dtype, args_dict, result_tensor, error_name
1223 )
1224
1225 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001226
James Ward8b390432022-08-12 20:48:56 +01001227 def build_matmul(
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001228 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
James Ward8b390432022-08-12 20:48:56 +01001229 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001230 assert len(inputs) == 2
1231 a, b = inputs
Jeremy Johnson1271c442023-09-05 11:39:26 +01001232 accum_dtype = args_dict["acc_type"]
1233 result_tensor = OutputShaper.matmulOp(
James Ward8b390432022-08-12 20:48:56 +01001234 self.ser, self.rng, a, b, accum_dtype, error_name
1235 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001236
1237 # Invalidate Input/Output list for error if checks.
1238 input_list = [a.name, b.name]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001239 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001240 pCount, cCount = op["operands"]
1241 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001242 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1243 self, error_name, input_list, output_list
1244 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001245
Les Bell729b0352021-11-24 10:28:21 +00001246 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001247 self.ser,
1248 validator_fcns,
1249 error_name,
1250 op=op,
1251 input_shape=a.shape,
1252 input_dtype=a.dtype,
1253 input2_shape=b.shape,
1254 input2_dtype=b.dtype,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001255 output_shape=result_tensor.shape,
1256 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001257 qinfo=qinfo,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001258 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001259 input_list=input_list,
1260 output_list=output_list,
1261 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001262 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001263 ):
1264 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001265
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001266 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001267 attr.MatMulAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001268
1269 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001270
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001271 compliance = self.tensorComplianceMetaData(
1272 op, a.dtype, args_dict, result_tensor, error_name
1273 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001274
1275 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001277 def build_reduce(
1278 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
1279 ):
1280 assert len(inputs) == 1
1281 a = inputs[0]
1282 axis = args_dict["axis"]
1283 result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddond6ce7252021-09-29 15:35:44 +01001284
1285 # Invalidate Input/Output list for error if checks.
1286 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001287 output_list = [result_tensor.name]
Matthew Haddond6ce7252021-09-29 15:35:44 +01001288 pCount, cCount = op["operands"]
1289 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001290 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1291 self, error_name, input_list, output_list
1292 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01001293
Les Bell729b0352021-11-24 10:28:21 +00001294 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01001295 self.ser,
1296 validator_fcns,
1297 error_name,
1298 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001299 axis=axis,
1300 input_shape=a.shape,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001301 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001302 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001303 output_dtype=result_tensor.dtype,
1304 result_tensors=[result_tensor],
Matthew Haddond6ce7252021-09-29 15:35:44 +01001305 input_list=input_list,
1306 output_list=output_list,
1307 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001308 ):
1309 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001310
1311 attr = ts.TosaSerializerAttribute()
1312 attr.AxisAttribute(axis)
1313
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001314 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001315
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001316 if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
1317 # Number of products - needed for compliance
1318 args_dict["n"] = a.shape[axis]
1319
1320 compliance = self.tensorComplianceMetaData(
1321 op, a.dtype, args_dict, result_tensor, error_name
1322 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001323
1324 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001325
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001326 def build_clamp(
1327 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1328 ):
1329 assert len(inputs) == 1
1330 a = inputs[0]
1331
1332 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001333
Jeremy Johnson18e26662021-07-22 16:15:29 +01001334 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001335
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001336 if error_name == ErrorIf.MaxSmallerMin:
1337 # Make sure the numbers are different to invoke this error
1338 while v[0] == v[1]:
1339 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
1340 max_val = min(v)
1341 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001342 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001343 max_val = max(v)
1344 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001345
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001346 # Invalidate Input/Output list for error if checks.
1347 input_list = [a.name]
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001348 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001349 pCount, cCount = op["operands"]
1350 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001351 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1352 self, error_name, input_list, output_list
1353 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001354
Les Bell729b0352021-11-24 10:28:21 +00001355 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001356 self.ser,
1357 validator_fcns,
1358 error_name,
1359 op=op,
1360 max_val=max_val,
1361 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001362 input_shape=a.shape,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001363 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001364 input_dtype=a.dtype,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001365 output_dtype=result_tensor.dtype,
1366 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001367 input_list=input_list,
1368 output_list=output_list,
1369 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001370 ):
1371 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001372
1373 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001374 if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
1375 if a.dtype == DType.FP16:
1376 # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
1377 min_val = min_val.astype(np.float32)
1378 max_val = max_val.astype(np.float32)
1379
1380 attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001381 else:
James Ward34071252022-12-07 15:48:47 +00001382 attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001383
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001384 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001385
1386 compliance = self.tensorComplianceMetaData(
1387 op, a.dtype, args_dict, result_tensor, error_name
1388 )
1389
1390 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001391
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001392 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
1393 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001394 attr = ts.TosaSerializerAttribute()
1395
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001396 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
Eric Kunzee5e26762020-10-13 16:11:07 -07001397
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001398 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001399 return result_tens
1400
1401 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001402 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1403 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001404
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001405 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001406 return result_tens
1407
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001408 def build_activation(
1409 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1410 ):
1411 assert len(inputs) == 1
1412 a = inputs[0]
1413
1414 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001415
1416 # Invalidate Input/Output list for error if checks.
1417 input_list = [a.name]
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001418 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001419 pCount, cCount = op["operands"]
1420 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001421 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1422 self, error_name, input_list, output_list
1423 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001424
Les Bell729b0352021-11-24 10:28:21 +00001425 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001426 self.ser,
1427 validator_fcns,
1428 error_name,
1429 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001430 input_shape=a.shape,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001431 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001432 input_dtype=a.dtype,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001433 output_dtype=result_tensor.dtype,
1434 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001435 input_list=input_list,
1436 output_list=output_list,
1437 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001438 ):
1439 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001440
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001441 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001442
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001443 compliance = self.tensorComplianceMetaData(
1444 op, a.dtype, args_dict, result_tensor, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001445 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001446
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001447 return TosaTestGen.BuildInfo(result_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001448
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001449 def build_concat(
1450 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1451 ):
Won Jeon74342e52024-01-09 00:34:40 +00001452 if op["op"] == Op.CONCAT_SHAPE:
1453 axis = 0
1454 else:
1455 axis = args_dict["axis"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001456 if error_name != ErrorIf.WrongInputType:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001457 assert type(axis) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001458
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001459 result_tensor = OutputShaper.concatOp(
1460 self.ser, self.rng, axis, inputs, error_name=error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001461 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001462
Matthew Haddon818ab902021-07-27 09:12:49 +01001463 input_tensor_names = []
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001464 for tensor in inputs:
Matthew Haddon818ab902021-07-27 09:12:49 +01001465 input_tensor_names.append(tensor.name)
1466
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001467 # Invalidate Input/Output list for error if checks.
1468 input_list = input_tensor_names
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001469 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001470 pCount, cCount = op["operands"]
1471 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001472 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1473 self, error_name, input_list, output_list
1474 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001475
Les Bell729b0352021-11-24 10:28:21 +00001476 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001477 self.ser,
1478 validator_fcns,
1479 error_name,
1480 op=op,
1481 axis=axis,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001482 input_shape=inputs[0].shape,
1483 output_shape=result_tensor.shape,
1484 input_dtype=inputs[0].dtype,
1485 output_dtype=result_tensor.dtype,
1486 inputs=inputs,
1487 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001488 input_list=input_list,
1489 output_list=output_list,
1490 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001491 ):
1492 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001493
Won Jeon74342e52024-01-09 00:34:40 +00001494 if op["op"] == Op.CONCAT:
1495 attr = ts.TosaSerializerAttribute()
1496 attr.AxisAttribute(axis)
1497 else:
1498 assert op["op"] == Op.CONCAT_SHAPE
1499 attr = None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001500 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001501
1502 compliance = self.tensorComplianceMetaData(
1503 op, inputs[0].dtype, args_dict, result_tensor, error_name
1504 )
1505
1506 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001507
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001508 def build_pad(
1509 self,
1510 op,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001511 inputs,
1512 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001513 validator_fcns=None,
1514 error_name=None,
1515 qinfo=None,
1516 ):
Tai Lye095da72024-01-25 22:00:18 +00001517 assert len(inputs) == 2
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001518 a = inputs[0]
Tai Lye095da72024-01-25 22:00:18 +00001519 pad_input = inputs[1]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001520 padding = args_dict["pad"]
1521 pad_const_int = args_dict["pad_const_int"]
1522 pad_const_float = args_dict["pad_const_fp"]
1523
1524 result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001525
Tai Lye095da72024-01-25 22:00:18 +00001526 # write empty padding into PadAttribute to ensure inputs[1] is used
Kevin Chengfe392ce2021-10-18 21:51:55 +00001527 attr = ts.TosaSerializerAttribute()
Tai Lye095da72024-01-25 22:00:18 +00001528 attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07001529
Matthew Haddone807aae2021-10-11 18:12:58 +01001530 # Invalidate Input/Output list for error if checks.
Tai Lye095da72024-01-25 22:00:18 +00001531 input_list = [a.name, pad_input.name]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001532 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001533 pCount, cCount = op["operands"]
1534 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001535 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1536 self, error_name, input_list, output_list
1537 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001538
Les Bell729b0352021-11-24 10:28:21 +00001539 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001540 self.ser,
1541 validator_fcns,
1542 error_name,
1543 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001544 input_shape=a.shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001545 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001546 input_dtype=a.dtype,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001547 output_dtype=result_tensor.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001548 pad=padding,
1549 qinfo=qinfo,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001550 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001551 input_list=input_list,
1552 output_list=output_list,
1553 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001554 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001555 ):
1556 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001557
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001558 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001559
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001560 compliance = self.tensorComplianceMetaData(
1561 op, a.dtype, args_dict, result_tensor, error_name
1562 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001563
1564 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001565
Won Jeona21b2e82023-08-10 10:33:01 +00001566 def build_dim(
1567 self,
1568 op,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001569 inputs,
1570 args_dict,
Won Jeona21b2e82023-08-10 10:33:01 +00001571 validator_fcns=None,
1572 error_name=None,
1573 qinfo=None,
1574 ):
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001575 assert len(inputs) == 1
1576 a = inputs[0]
1577 axis = args_dict["axis"]
1578 result_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
Won Jeona21b2e82023-08-10 10:33:01 +00001579
1580 # Invalidate Input/Output list for error if checks.
1581 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001582 output_list = [result_tensor.name]
Won Jeona21b2e82023-08-10 10:33:01 +00001583 pCount, cCount = op["operands"]
1584 num_operands = pCount + cCount
1585 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1586 self, error_name, input_list, output_list
1587 )
1588
1589 if not TosaErrorValidator.evValidateErrorIfs(
1590 self.ser,
1591 validator_fcns,
1592 error_name,
1593 op=op,
1594 axis=axis,
1595 input_shape=a.shape,
1596 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001597 output_shape=result_tensor.shape,
1598 output_dtype=result_tensor.dtype,
1599 result_tensors=[result_tensor],
Won Jeona21b2e82023-08-10 10:33:01 +00001600 input_list=input_list,
1601 output_list=output_list,
1602 num_operands=num_operands,
1603 ):
1604 return None
1605
1606 attr = ts.TosaSerializerAttribute()
1607 attr.AxisAttribute(axis)
1608
1609 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001610 return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001611
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001612 def build_reshape(
1613 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1614 ):
Tai Ly8690a082023-12-18 20:40:24 +00001615 assert len(inputs) == 2
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001616 a = inputs[0]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001617 shape = inputs[1]
1618 shape_attr = args_dict["new_shape"]
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001619 result_tensor = OutputShaper.reshapeOp(
Won Jeon64e4bfe2024-01-18 06:31:55 +00001620 self.ser, self.rng, a, shape_attr, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001621 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001622
1623 # Invalidate Input/Output list for error if checks.
Won Jeon64e4bfe2024-01-18 06:31:55 +00001624 input_list = [a.name, shape.name]
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001625 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001626 pCount, cCount = op["operands"]
1627 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001628 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1629 self, error_name, input_list, output_list
1630 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001631
Les Bell729b0352021-11-24 10:28:21 +00001632 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001633 self.ser,
1634 validator_fcns,
1635 error_name,
1636 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001637 input_shape=a.shape,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001638 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001639 input_dtype=a.dtype,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001640 output_dtype=result_tensor.dtype,
1641 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001642 input_list=input_list,
1643 output_list=output_list,
1644 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001645 ):
1646 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001647
Tai Ly8690a082023-12-18 20:40:24 +00001648 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001649
1650 compliance = self.tensorComplianceMetaData(
1651 op, a.dtype, args_dict, result_tensor, error_name
1652 )
1653
1654 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001655
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001656 def build_reverse(
1657 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1658 ):
1659 assert len(inputs) == 1
1660 a = inputs[0]
1661 axis = args_dict["axis"]
1662 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001663
1664 # Invalidate Input/Output list for error if checks.
1665 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001666 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001667 pCount, cCount = op["operands"]
1668 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001669 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1670 self, error_name, input_list, output_list
1671 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001672
Les Bell729b0352021-11-24 10:28:21 +00001673 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001674 self.ser,
1675 validator_fcns,
1676 error_name,
1677 op=op,
1678 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001679 input_shape=a.shape,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001680 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001681 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001682 output_dtype=result_tensor.dtype,
1683 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001684 input_list=input_list,
1685 output_list=output_list,
1686 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001687 ):
1688 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001689
1690 attr = ts.TosaSerializerAttribute()
1691 attr.AxisAttribute(axis)
1692
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001693 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001694 return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001695
evacha0198477222024-01-26 12:25:32 +00001696 def build_transpose(
1697 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1698 ):
1699 assert len(inputs) == 1
1700 a = inputs[0]
1701 perms = args_dict["perms"]
1702
1703 result_tensor = OutputShaper.transposeOp(
1704 self.ser, self.rng, a, perms, error_name
1705 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001706
Kevin Chengfe392ce2021-10-18 21:51:55 +00001707 attr = ts.TosaSerializerAttribute()
1708 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07001709
Matthew Haddone807aae2021-10-11 18:12:58 +01001710 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001711 input_list = [a.name]
evacha0198477222024-01-26 12:25:32 +00001712 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001713 pCount, cCount = op["operands"]
1714 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001715 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1716 self, error_name, input_list, output_list
1717 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001718
Les Bell729b0352021-11-24 10:28:21 +00001719 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001720 self.ser,
1721 validator_fcns,
1722 error_name,
1723 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001724 input_shape=a.shape,
evacha0198477222024-01-26 12:25:32 +00001725 output_shape=result_tensor.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01001726 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001727 input_dtype=a.dtype,
evacha0198477222024-01-26 12:25:32 +00001728 output_dtype=result_tensor.dtype,
1729 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001730 input_list=input_list,
1731 output_list=output_list,
1732 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001733 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001734 ):
1735 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001736
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001737 self.ser.addOperator(op["op"], input_list, output_list, attr)
evacha0198477222024-01-26 12:25:32 +00001738
1739 compliance = self.tensorComplianceMetaData(
1740 op, a.dtype, args_dict, result_tensor, error_name
1741 )
1742
1743 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001744
evacha017f7d4252024-01-24 12:08:09 +00001745 def build_slice(
1746 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1747 ):
TatWai Chongf15bad82024-01-31 21:33:27 -08001748 assert len(inputs) == 3
1749 a, start_var, size_var = inputs
1750 start_const = args_dict["start"]
1751 size_const = args_dict["size"]
evacha017f7d4252024-01-24 12:08:09 +00001752
1753 result_tensor = OutputShaper.sliceOp(
TatWai Chongf15bad82024-01-31 21:33:27 -08001754 self.ser, self.rng, a, start_const, size_const, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001755 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001756
1757 # Invalidate Input/Output list for error if checks.
TatWai Chongf15bad82024-01-31 21:33:27 -08001758 input_list = [a.name, start_var.name, size_var.name]
evacha017f7d4252024-01-24 12:08:09 +00001759 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001760 pCount, cCount = op["operands"]
1761 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001762 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1763 self, error_name, input_list, output_list
1764 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001765
Les Bell729b0352021-11-24 10:28:21 +00001766 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001767 self.ser,
1768 validator_fcns,
1769 error_name,
1770 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001771 input_shape=a.shape,
evacha017f7d4252024-01-24 12:08:09 +00001772 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001773 input_dtype=a.dtype,
evacha017f7d4252024-01-24 12:08:09 +00001774 output_dtype=result_tensor.dtype,
TatWai Chongf15bad82024-01-31 21:33:27 -08001775 start=start_const,
1776 size=size_const,
evacha017f7d4252024-01-24 12:08:09 +00001777 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001778 input_list=input_list,
1779 output_list=output_list,
1780 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001781 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001782 ):
1783 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001784
TatWai Chongf15bad82024-01-31 21:33:27 -08001785 # TODO remove the slice attribute once shape dynamism support is mature.
Eric Kunzee5e26762020-10-13 16:11:07 -07001786 attr = ts.TosaSerializerAttribute()
TatWai Chongf15bad82024-01-31 21:33:27 -08001787 attr.SliceAttribute(start_const, size_const)
Eric Kunzee5e26762020-10-13 16:11:07 -07001788
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001789 self.ser.addOperator(op["op"], input_list, output_list, attr)
evacha017f7d4252024-01-24 12:08:09 +00001790
1791 compliance = self.tensorComplianceMetaData(
1792 op, a.dtype, args_dict, result_tensor, error_name
1793 )
1794
1795 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001796
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001797 def build_tile(
1798 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1799 ):
Tai Ly8690a082023-12-18 20:40:24 +00001800 assert len(inputs) == 2
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001801 a = inputs[0]
Won Jeon64e4bfe2024-01-18 06:31:55 +00001802 multiples = inputs[1]
Tai Ly8690a082023-12-18 20:40:24 +00001803 multiples_attr = args_dict["multiples"]
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001804 result_tensor = OutputShaper.tileOp(
Tai Ly8690a082023-12-18 20:40:24 +00001805 self.ser, self.rng, a, multiples_attr, error_name
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001806 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001807
1808 # Invalidate Input/Output list for error if checks.
Tai Ly8690a082023-12-18 20:40:24 +00001809 input_list = [a.name, multiples.name]
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001810 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001811 pCount, cCount = op["operands"]
1812 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001813 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1814 self, error_name, input_list, output_list
1815 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001816
Les Bell729b0352021-11-24 10:28:21 +00001817 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001818 self.ser,
1819 validator_fcns,
1820 error_name,
1821 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001822 input_shape=a.shape,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001823 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001824 input_dtype=a.dtype,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001825 output_dtype=result_tensor.dtype,
1826 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001827 input_list=input_list,
1828 output_list=output_list,
1829 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001830 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001831 ):
1832 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001833
Tai Ly8690a082023-12-18 20:40:24 +00001834 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00001835
1836 compliance = self.tensorComplianceMetaData(
1837 op, a.dtype, args_dict, result_tensor, error_name
1838 )
1839
1840 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001841
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001842 def build_gather(
1843 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1844 ):
1845 assert len(inputs) == 2
1846 values, indices = inputs
Eric Kunzee5e26762020-10-13 16:11:07 -07001847
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001848 result_tensor = OutputShaper.gatherOp(
1849 self.ser, self.rng, values, indices, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001850 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001851
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001852 # Invalidate Input/Output list for error if checks.
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001853 input_list = [values.name, indices.name]
1854 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001855 pCount, cCount = op["operands"]
1856 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001857 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1858 self, error_name, input_list, output_list
1859 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001860
Les Bell729b0352021-11-24 10:28:21 +00001861 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001862 self.ser,
1863 validator_fcns,
1864 error_name,
1865 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001866 input_shape=values.shape,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001867 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001868 input_dtype=values.dtype,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001869 output_dtype=result_tensor.dtype,
1870 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001871 input_list=input_list,
1872 output_list=output_list,
1873 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001874 ):
1875 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001876
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001877 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001878
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001879 compliance = self.tensorComplianceMetaData(
1880 op, values.dtype, args_dict, result_tensor, error_name
1881 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001882
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001883 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001884
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001885 def build_scatter(
1886 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1887 ):
1888 assert len(inputs) == 3
1889 values_in, indices, input = inputs
1890 result_tensor = OutputShaper.scatterOp(
Jeremy Johnson194fe312023-12-07 14:17:57 +00001891 self.ser, self.rng, values_in, indices, input, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001892 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001893
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001894 # Invalidate Input/Output list for error if checks.
Jeremy Johnson194fe312023-12-07 14:17:57 +00001895 input_list = [values_in.name, indices.name, input.name]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001896 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001897 pCount, cCount = op["operands"]
1898 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001899 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1900 self, error_name, input_list, output_list
1901 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001902
Les Bell729b0352021-11-24 10:28:21 +00001903 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001904 self.ser,
1905 validator_fcns,
1906 error_name,
1907 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001908 input_shape=values_in.shape,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001909 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001910 input_dtype=values_in.dtype,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001911 output_dtype=result_tensor.dtype,
1912 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001913 input_list=input_list,
1914 output_list=output_list,
1915 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001916 ):
1917 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08001918
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001919 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001920
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001921 compliance = self.tensorComplianceMetaData(
1922 op, values_in.dtype, args_dict, result_tensor, error_name
1923 )
1924
1925 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001926
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator and add it to the serializer.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        input: Input tensor.
        mode: Resize mode, forwarded to the output shaper, the ERROR_IF
            validators and the ResizeAttribute.
        scale: Resize scale parameters (forwarded the same way).
        offset: Resize offset parameters (forwarded the same way).
        border: Resize border parameters (forwarded the same way).
        input_dtype: Input data type used for output shaping and validation.
        output_dtype: Output data type used for output shaping and validation.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    pCount, cCount = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1988
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Accepted for signature compatibility; not used here.
        error_name: ERROR_IF under test, or None; only forwarded to the
            output shaper.

    Returns:
        The first result tensor. The second output is attached to the
        operator but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Hand the serializer the operator stored in the definition dict (as the
    # other build_* methods do), not the dict itself.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1996
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test, where the single input tensor is the graph output.

    Args:
        op: Operator definition dict, forwarded to the compliance metadata.
        inputs: Sequence holding exactly one tensor.
        args_dict: Test argument dict, forwarded to the compliance metadata.
        validator_fcns: Not used.
        error_name: ERROR_IF under test, or None; forwarded to the compliance
            metadata.
        qinfo: Not used.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the tensor and its compliance
        metadata.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]

    # No operator is added: the tensor itself is registered as the output.
    self.ser.addOutputTensor(const_tens)

    meta = self.tensorComplianceMetaData(
        op, const_tens.dtype, args_dict, const_tens, error_name
    )
    return TosaTestGen.BuildInfo(const_tens, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002009
2010 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator converting one tensor to ``args_dict["out_type"]``.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"out_type"`` selects the target dtype.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Not used.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    pCount, cCount = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tens.shape,
        input_dtype=src.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    # Compliance metadata is requested with the input dtype, not the output's.
    meta = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002054
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator with random zero points and scale factors.

    Multiplier/shift pairs are derived from a random scale. For valid
    (``error_name is None``) scale32 tests, the input data file is clipped so
    every value satisfies the apply_scale_32 shift requirement, and it is
    rewritten if anything changed.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        val: Input tensor. For valid scale32 tests its placeholder file is
            loaded from ``self.basePath/self.testPath`` and may be overwritten.
        out_dtype: Output data type.
        scale32: Selects the scale32 range (capped at 2^31 - 1) instead of the
            scale16 range (capped at 2^15 - 1).
        double_round: Forwarded to the validators and the RescaleAttribute.
        per_channel: If True, one multiplier/shift pair per element of the
            last input dimension; otherwise a single pair.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input side first, then output side: keeps the RNG draw order fixed.
    input_zp, in_width_adj, input_unsigned = self._rescale_zero_point(
        val.dtype,
        error_name,
        (ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid),
    )
    output_zp, out_width_adj, output_unsigned = self._rescale_zero_point(
        out_dtype,
        error_name,
        (ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid),
    )
    in_type_width += in_width_adj
    out_type_width += out_width_adj

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width), with a random factor a
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds int32 scalars; do the bound arithmetic in int64 so
        # large shifts cannot overflow 32-bit arithmetic (NumPy 2 keeps the
        # int32 type when combined with a Python int).
        shift = np.int64(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the data's dtype. Compare the
        # values element-wise: ndarray.all() would first collapse them to a
        # single bool and make the check meaningless.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting back to the input data's dtype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens

def _rescale_zero_point(self, dtype, error_name, zp_error_names):
    """Choose the RESCALE zero point for one side (input or output).

    Args:
        dtype: Data type of that side.
        error_name: ERROR_IF under test, or None.
        zp_error_names: Zero-point ERROR_IFs that apply to that side. Unless
            ``dtype`` is INT8/UINT8, a non-zero zero point in [-128, 137) is
            chosen when ``error_name`` is one of them.

    Returns:
        Tuple ``(zero_point, width_adjust, is_unsigned)``. ``width_adjust``
        (0 or 1) is added to the type width used in the scale calculation.
    """
    if dtype == DType.INT8:
        return self.randInt(-128, 128), 1, False
    if dtype == DType.UINT8:
        return self.randInt(0, 256), 1, True
    if error_name in zp_error_names:
        zp = self.randInt(-128, 128)
        if zp == 0:
            zp = zp + self.rng.integers(1, 10)
        return zp, 1, False
    if dtype == DType.UINT16:
        # Must come after the U16*ZeroPointNotValid ERROR_IF check above
        return self.rng.choice([0, 32768]), 1, True
    return 0, 0, False
2227
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Args:
        op: Operator definition dict (used to pick a wrong dtype for the
            CondIfCondNotMatchingBool test).
        cond: Value stored in the condition tensor.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The constant tensor added to the serializer. It is BOOL and rank 0
        unless one of the COND_IF condition ERROR_IFs asks otherwise.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2244
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    The supplied then/else tensors are ignored except for the shape of
    ``then_tens``, which defines the output shape.

    Args:
        op: Operator definition dict; ``op["op"]`` is the COND_IF operator.
        then_tens: Tensor whose shape is used for the block outputs.
        else_tens: Ignored.
        cond: Value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch
    if error_name in (then_mismatch, else_mismatch):
        # Build a wrongly-shaped constant for the mismatching block
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] = dim + self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor takes the shape shared by both block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant, which is also its output
    for block_name, block_arr, block_error in (
        (then_block, then_arr, then_mismatch),
        (else_block, else_arr, else_mismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2313
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD (then) / SUB (else). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT (else).

    Args:
        op: Operator definition dict; ``op["op"]`` is the COND_IF operator.
            The same dict is passed to the ERROR_IF validators.
        a: First input tensor; its shape and dtype define the result.
        b: Second input tensor.
        cond: Value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op mapping.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension to build a mismatching block tensor
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Raise explicitly so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop variable: `op` must still be the operator definition
    # dict when it is passed to the ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2396
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that adds ``a`` into a zero-initialised accumulator.

    The COND block tests ``iter > 0``. The BODY block computes
    ``acc = a + acc`` and ``iter = iter - 1``.

    Args:
        op: Operator definition dict; ``op["op"]`` is the WHILE_LOOP operator.
        a: Tensor added to the accumulator on each iteration.
        iter_val: Initial value of the INT32 iteration counter.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation
        rejects the test.
    """
    # Named iter_tens to avoid shadowing the builtin iter()
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # Rank 0 counter: give it a dimension so the shape mismatches
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2517
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator from two input tensors.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        inputs: Sequence of exactly two input tensors.
        args_dict: Test arguments; ``"inverse"`` selects the inverse transform.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Not used.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    val1, val2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    pCount, cCount = op["operands"]

    result_names = [res.name for res in results]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # ERROR_IF tests may have their input/output name lists invalidated here.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], result_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)

    # One compliance entry per output, each requested with the first input's dtype
    compliance = [
        self.tensorComplianceMetaData(op, val1.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002581
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator for a single input tensor.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        val: Input tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.

    Returns:
        The result tensors produced by ``OutputShaper.rfft2dOp``, or None when
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    pCount, cCount = op["operands"]

    result_names = [res.name for res in results]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # ERROR_IF tests may have their input/output name lists invalidated here.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], result_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)
    return results
2627
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator.

    The result tensor comes from ``OutputShaper.addShapeOp``.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator handed to
            the serializer and ``op["operands"]`` is a pair of operand counts
            whose sum is passed to the validators as ``num_operands``.
        inputs: Sequence of exactly two input tensors.
        args_dict: Test argument dict, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: Not used.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tens = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    pCount, cCount = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    meta = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, meta)
2673
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes, or a falsy value for "no shape filter".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An op dtype that is itself a
            list is matched on its first element.
        testType: "positive" or "negative".
        validator: ERROR_IF validator (only used for negative tests).

    Returns:
        dict with keys "shapeFilter", "rankFilter" and "dtypeFilter", or None
        for a negative test without a validator or an unknown testType.
    """
    shapes = shapeFilter if shapeFilter else [None]

    op_min_rank, op_max_rank = op["rank"]
    op_ranks = range(op_min_rank, op_max_rank + 1)

    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        ranks = [rank for rank in rankFilter if op_min_rank <= rank <= op_max_rank]
    elif shapes[0] is None:
        # No rank or shape requested: use the smaller of the operator's rank
        # range and the default 1-4 range, to keep test sizes reasonably small
        default_ranks = range(1, 5)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    return {
        # Only the first two shapes, to keep the number of negative tests small
        "shapeFilter": reqs["shape"] if reqs["shape"] is not None else shapes[:2],
        "rankFilter": reqs["rank"] if reqs["rank"] is not None else ranks,
        "dtypeFilter": reqs["dtype"] if reqs["dtype"] is not None else dtypes,
    }
2754
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: list of shapes to restrict to; None (the default) or
            [None] means no shape restriction.
        rankFilter: list of ranks to restrict to, or None.
        dtypeFilter: list of dtypes to restrict to, or None.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        list of tuples (opName, testStr, dtype, error_name, shapeList, args).

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Use a fresh list rather than a shared mutable default argument
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Each test is a tuple of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Skip requested shapes that do not match this rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (string representation,
                    # build function argument list) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    if testType == "positive":
                        prefix = opName
                    else:
                        prefix = "{}_ERRORIF_{}".format(opName, error_name)

                    for argStr, args in argList:
                        testStr = "{}_{}_{}".format(prefix, shapeStr, typeStr)
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2861
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Create the tensors and operator for one test and serialize it.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: name of the test, passed to the serializer and used in messages.
        dtype_or_dtypeList: one dtype used for every operand, or an explicit
            list of per-operand dtypes.
        error_name: ERROR_IF condition under test, or None for positive tests.
        shapeList: list of operand shapes.
        testArgs: an args dictionary (new interface; must contain "dg_type")
            or a list of positional build arguments (old interface).

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = (
        op["error_if_validators"] if "error_if_validators" in op else None
    )

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts
    # CONCAT-style ops take a variable number of inputs
    variadic = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if variadic else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not variadic:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: argument info is held in a dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        result = build_fcn(
            self,
            op,
            tvgInfo.tensorList,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: argument info is a positional list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        try:
            if error_if_validators is None:
                if qinfo is not None:
                    result = build_fcn(self, op, *tens, *testArgs, qinfo)
                else:
                    result = build_fcn(self, op, *tens, *testArgs)
            else:
                extra_kwargs = {"qinfo": qinfo} if qinfo is not None else {}
                result = build_fcn(
                    self,
                    op,
                    *tens,
                    *testArgs,
                    validator_fcns=error_if_validators,
                    error_name=error_name,
                    **extra_kwargs,
                )
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The build function rejected the test
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    if isinstance(result, TosaTestGen.BuildInfo):
        # Attach the compliance meta data, when there is any
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002985
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    For each kernel size a copy of the matching template is added under a
    name such as "conv2d_3x3", with "filter" set to the kernel and
    "template" set to False. Afterwards every entry whose "template" field
    is truthy is removed. Does nothing if "conv2d_TEMPLATE" is already gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already expanded (can occur when the class is initialized more than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def instantiate(prefix, kernel):
        # Shallow-copy the template; only "filter" and "template" are replaced.
        entry = op_list[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[prefix + "_" + "x".join(map(str, kernel))] = entry

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)

    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Collect the template names first: deleting keys while iterating a dict
    # is an error.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
3041
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must have a 2-element "operands" sequence, a
    4-element "build_fcn" sequence, a "types" field and an "op" field. A
    missing "rank" is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op with a missing or malformed required
            field. For the unpacked fields, the original lookup/unpack error
            is chained as the cause.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _, _ = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _, _, _, _ = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003083
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op value for the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': list of datatypes to be tested; an element may itself be a list of
#     dtypes (see TYPE_CONV)
# Optional fields: 'qgen', 'error_if_validators', 'invalid_test_validators',
#     'data_gen', and for templated convolutions 'filter' and 'template'
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003135
3136 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003138 "argmax": {
3139 "op": Op.ARGMAX,
3140 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003141 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 "build_fcn": (
3143 build_argmax,
3144 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003145 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003146 TosaArgGen.agAxis,
3147 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003148 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003149 "error_if_validators": (
3150 TosaErrorValidator.evAxisSmallerZero,
3151 TosaErrorValidator.evAxisLargerRank,
3152 TosaErrorValidator.evArgmaxOutputRankMismatch,
3153 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3154 TosaErrorValidator.evWrongRank,
3155 TosaErrorValidator.evWrongInputType,
3156 TosaErrorValidator.evWrongOutputType,
3157 TosaErrorValidator.evWrongInputList,
3158 TosaErrorValidator.evWrongOutputList,
3159 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003160 "data_gen": {
3161 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3162 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 "avg_pool2d": {
3165 "op": Op.AVG_POOL2D,
3166 "operands": (1, 0),
3167 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003168 "build_fcn": (
3169 build_pool2d,
3170 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003171 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003172 TosaArgGen.agPooling,
3173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003174 "qgen": TosaQuantGen.qgUnary,
3175 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003176 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 "error_if_validators": (
3178 TosaErrorValidator.evKernelSmallerOne,
3179 TosaErrorValidator.evStrideSmallerOne,
3180 TosaErrorValidator.evPadSmallerZero,
3181 TosaErrorValidator.evWrongRank,
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 TosaErrorValidator.evInputZeroPointNotZero,
3187 TosaErrorValidator.evOutputZeroPointNotZero,
3188 TosaErrorValidator.evPadLargerEqualKernel,
3189 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003190 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003191 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003192 "data_gen": {
3193 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003196 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003197 "conv2d_TEMPLATE": {
3198 "op": Op.CONV2D,
3199 "operands": (1, 2),
3200 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003201 "build_fcn": (
3202 build_conv2d,
3203 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003204 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003205 TosaArgGen.agConv,
3206 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003207 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003208 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003209 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3210 "error_if_validators": (
3211 TosaErrorValidator.evWrongInputType,
3212 TosaErrorValidator.evWrongOutputType,
3213 TosaErrorValidator.evWrongInputList,
3214 TosaErrorValidator.evWrongOutputList,
3215 TosaErrorValidator.evInputZeroPointNotZero,
3216 TosaErrorValidator.evWeightZeroPointNotZero,
3217 TosaErrorValidator.evPadSmallerZero,
3218 TosaErrorValidator.evStrideSmallerOne,
3219 TosaErrorValidator.evDilationSmallerOne,
3220 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003221 TosaErrorValidator.evConvOutputShapeMismatch,
3222 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003223 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003224 "data_gen": {
3225 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3226 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003227 "template": True,
3228 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003229 # Templated operator. Filled in by createDynamicOpLists
3230 "conv3d_TEMPLATE": {
3231 "op": Op.CONV3D,
3232 "operands": (1, 2),
3233 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 "build_fcn": (
3235 build_conv3d,
3236 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003237 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 TosaArgGen.agConv,
3239 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003240 "qgen": TosaQuantGen.qgConv,
3241 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003242 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3243 "error_if_validators": (
3244 TosaErrorValidator.evWrongInputType,
3245 TosaErrorValidator.evWrongOutputType,
3246 TosaErrorValidator.evWrongInputList,
3247 TosaErrorValidator.evWrongOutputList,
3248 TosaErrorValidator.evInputZeroPointNotZero,
3249 TosaErrorValidator.evWeightZeroPointNotZero,
3250 TosaErrorValidator.evPadSmallerZero,
3251 TosaErrorValidator.evStrideSmallerOne,
3252 TosaErrorValidator.evDilationSmallerOne,
3253 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003254 TosaErrorValidator.evConvOutputShapeMismatch,
3255 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003256 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003257 "template": True,
3258 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003259 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003260 "depthwise_conv2d_TEMPLATE": {
3261 "op": Op.DEPTHWISE_CONV2D,
3262 "operands": (1, 2),
3263 "filter": [1, 1],
3264 "rank": (4, 4),
3265 "build_fcn": (
3266 build_depthwise_conv2d,
3267 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003268 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003269 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003270 ),
3271 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003272 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003273 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3274 "error_if_validators": (
3275 TosaErrorValidator.evWrongInputType,
3276 TosaErrorValidator.evWrongOutputType,
3277 TosaErrorValidator.evWrongInputList,
3278 TosaErrorValidator.evWrongOutputList,
3279 TosaErrorValidator.evInputZeroPointNotZero,
3280 TosaErrorValidator.evWeightZeroPointNotZero,
3281 TosaErrorValidator.evPadSmallerZero,
3282 TosaErrorValidator.evStrideSmallerOne,
3283 TosaErrorValidator.evDilationSmallerOne,
3284 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003285 TosaErrorValidator.evConvOutputShapeMismatch,
3286 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003287 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003288 "data_gen": {
3289 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3290 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003291 "template": True,
3292 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 "fully_connected": {
3294 "op": Op.FULLY_CONNECTED,
3295 "operands": (1, 2),
3296 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003297 "build_fcn": (
3298 build_fully_connected,
3299 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003300 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003301 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003302 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003304 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 "error_if_validators": (
3306 TosaErrorValidator.evInputZeroPointNotZero,
3307 TosaErrorValidator.evWeightZeroPointNotZero,
3308 TosaErrorValidator.evWrongRank,
3309 TosaErrorValidator.evWrongInputType,
3310 TosaErrorValidator.evWrongOutputType,
3311 TosaErrorValidator.evWrongInputList,
3312 TosaErrorValidator.evWrongOutputList,
3313 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003314 "data_gen": {
3315 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "matmul": {
3319 "op": Op.MATMUL,
3320 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003321 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003322 "build_fcn": (
3323 build_matmul,
3324 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003325 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003326 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003327 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 "qgen": TosaQuantGen.qgMatmul,
3329 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 "error_if_validators": (
3331 TosaErrorValidator.evInputZeroPointNotZero,
3332 TosaErrorValidator.evWrongRank,
3333 TosaErrorValidator.evWrongInputType,
3334 TosaErrorValidator.evWrongOutputType,
3335 TosaErrorValidator.evWrongInputList,
3336 TosaErrorValidator.evWrongOutputList,
3337 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003338 "data_gen": {
3339 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003340 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 "max_pool2d": {
3343 "op": Op.MAX_POOL2D,
3344 "operands": (1, 0),
3345 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003347 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003349 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003350 TosaArgGen.agPooling,
3351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003353 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003354 "error_if_validators": (
3355 TosaErrorValidator.evKernelSmallerOne,
3356 TosaErrorValidator.evStrideSmallerOne,
3357 TosaErrorValidator.evPadSmallerZero,
3358 TosaErrorValidator.evWrongRank,
3359 TosaErrorValidator.evWrongInputType,
3360 TosaErrorValidator.evWrongOutputType,
3361 TosaErrorValidator.evWrongInputList,
3362 TosaErrorValidator.evWrongOutputList,
3363 TosaErrorValidator.evPadLargerEqualKernel,
3364 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003365 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003366 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003367 "data_gen": {
3368 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003371 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003372 "transpose_conv2d_TEMPLATE": {
3373 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003374 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003375 "rank": (4, 4),
3376 "build_fcn": (
3377 build_transpose_conv2d,
3378 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003379 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003380 TosaArgGen.agTransposeConv2D,
3381 ),
3382 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003383 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003384 "invalid_test_validators": (
3385 TosaInvalidValidator.ivHeightWidthInvalid,
3386 TosaInvalidValidator.ivNonPositiveOutputShape,
3387 ),
3388 "error_if_validators": (
3389 TosaErrorValidator.evWrongInputType,
3390 TosaErrorValidator.evWrongOutputType,
3391 TosaErrorValidator.evWrongInputList,
3392 TosaErrorValidator.evWrongOutputList,
3393 TosaErrorValidator.evInputZeroPointNotZero,
3394 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003395 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003396 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003397 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003398 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003399 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003400 "data_gen": {
3401 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3402 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003403 "template": True,
3404 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003405 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003406 "clamp": {
3407 "op": Op.CLAMP,
3408 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003409 "build_fcn": (
3410 build_clamp,
3411 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003412 TosaTensorValuesGen.tvgLazyGenDefault,
3413 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003415 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 "error_if_validators": (
3417 TosaErrorValidator.evMaxSmallerMin,
3418 TosaErrorValidator.evWrongInputType,
3419 TosaErrorValidator.evWrongOutputType,
3420 TosaErrorValidator.evWrongInputList,
3421 TosaErrorValidator.evWrongOutputList,
3422 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003423 "data_gen": {
3424 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3425 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003426 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003427 "sigmoid": {
3428 "op": Op.SIGMOID,
3429 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003430 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003431 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003433 TosaTensorValuesGen.tvgLazyGenDefault,
3434 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003436 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 "error_if_validators": (
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evWrongOutputType,
3440 TosaErrorValidator.evWrongInputList,
3441 TosaErrorValidator.evWrongOutputList,
3442 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003443 "data_gen": {
3444 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3445 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003446 },
3447 "tanh": {
3448 "op": Op.TANH,
3449 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003450 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003451 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003453 TosaTensorValuesGen.tvgLazyGenDefault,
3454 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003455 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003456 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 "error_if_validators": (
3458 TosaErrorValidator.evWrongInputType,
3459 TosaErrorValidator.evWrongOutputType,
3460 TosaErrorValidator.evWrongInputList,
3461 TosaErrorValidator.evWrongOutputList,
3462 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003463 "data_gen": {
3464 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3465 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003466 "compliance": {
3467 "abs_error_lower_bound": 0.5,
3468 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003469 },
Won Jeon78155c62023-06-10 00:20:04 +00003470 "erf": {
3471 "op": Op.ERF,
3472 "operands": (1, 0),
3473 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003474 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003475 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003476 TosaTensorValuesGen.tvgLazyGenDefault,
3477 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003478 ),
3479 "types": TYPE_FP,
3480 "error_if_validators": (
3481 TosaErrorValidator.evWrongInputType,
3482 TosaErrorValidator.evWrongOutputType,
3483 TosaErrorValidator.evWrongInputList,
3484 TosaErrorValidator.evWrongOutputList,
3485 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003486 "data_gen": {
3487 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3488 },
3489 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 # Elementwise Binary Operators
3492 "add": {
3493 "op": Op.ADD,
3494 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003495 "build_fcn": (
3496 build_binary_broadcast,
3497 TosaTensorGen.tgBroadcastFuzz,
3498 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003499 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 "error_if_validators": (
3503 TosaErrorValidator.evRankMismatch,
3504 TosaErrorValidator.evWrongInputType,
3505 TosaErrorValidator.evWrongOutputType,
3506 TosaErrorValidator.evWrongInputList,
3507 TosaErrorValidator.evWrongOutputList,
3508 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003509 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003511 "data_gen": {
3512 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3513 },
3514 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003516 "arithmetic_right_shift": {
3517 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3518 "operands": (2, 0),
3519 "build_fcn": (
3520 build_arithmetic_right_shift,
3521 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 TosaArgGen.agArithmeticRightShift,
3524 ),
3525 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evRankMismatch,
3528 TosaErrorValidator.evWrongInputType,
3529 TosaErrorValidator.evWrongOutputType,
3530 TosaErrorValidator.evWrongInputList,
3531 TosaErrorValidator.evWrongOutputList,
3532 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003533 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "bitwise_and": {
3537 "op": Op.BITWISE_AND,
3538 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_binary_broadcast,
3541 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003542 TosaTensorValuesGen.tvgLazyGenDefault,
3543 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evRankMismatch,
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003553 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 "bitwise_or": {
3557 "op": Op.BITWISE_OR,
3558 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 "build_fcn": (
3560 build_binary_broadcast,
3561 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003562 TosaTensorValuesGen.tvgLazyGenDefault,
3563 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003566 "error_if_validators": (
3567 TosaErrorValidator.evRankMismatch,
3568 TosaErrorValidator.evWrongInputType,
3569 TosaErrorValidator.evWrongOutputType,
3570 TosaErrorValidator.evWrongInputList,
3571 TosaErrorValidator.evWrongOutputList,
3572 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003573 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003576 "bitwise_xor": {
3577 "op": Op.BITWISE_XOR,
3578 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 "build_fcn": (
3580 build_binary_broadcast,
3581 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003582 TosaTensorValuesGen.tvgLazyGenDefault,
3583 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003586 "error_if_validators": (
3587 TosaErrorValidator.evRankMismatch,
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongInputList,
3591 TosaErrorValidator.evWrongOutputList,
3592 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003593 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003596 "intdiv": {
3597 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003598 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 "build_fcn": (
3600 build_binary_broadcast,
3601 TosaTensorGen.tgBroadcastFuzz,
3602 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003603 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003604 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003605 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003606 "error_if_validators": (
3607 TosaErrorValidator.evRankMismatch,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongInputList,
3611 TosaErrorValidator.evWrongOutputList,
3612 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003613 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003615 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003616 "logical_and": {
3617 "op": Op.LOGICAL_AND,
3618 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 "build_fcn": (
3620 build_binary_broadcast,
3621 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003622 TosaTensorValuesGen.tvgLazyGenDefault,
3623 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 "error_if_validators": (
3627 TosaErrorValidator.evRankMismatch,
3628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongInputList,
3631 TosaErrorValidator.evWrongOutputList,
3632 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003633 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003634 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "logical_left_shift": {
3637 "op": Op.LOGICAL_LEFT_SHIFT,
3638 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_binary_broadcast,
3641 TosaTensorGen.tgBroadcastFuzz,
3642 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003643 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evRankMismatch,
3648 TosaErrorValidator.evWrongInputType,
3649 TosaErrorValidator.evWrongOutputType,
3650 TosaErrorValidator.evWrongInputList,
3651 TosaErrorValidator.evWrongOutputList,
3652 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003653 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 "logical_right_shift": {
3657 "op": Op.LOGICAL_RIGHT_SHIFT,
3658 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 "build_fcn": (
3660 build_binary_broadcast,
3661 TosaTensorGen.tgBroadcastFuzz,
3662 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003663 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003666 "error_if_validators": (
3667 TosaErrorValidator.evRankMismatch,
3668 TosaErrorValidator.evWrongInputType,
3669 TosaErrorValidator.evWrongOutputType,
3670 TosaErrorValidator.evWrongInputList,
3671 TosaErrorValidator.evWrongOutputList,
3672 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003673 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "logical_or": {
3677 "op": Op.LOGICAL_OR,
3678 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_binary_broadcast,
3681 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003682 TosaTensorValuesGen.tvgLazyGenDefault,
3683 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003686 "error_if_validators": (
3687 TosaErrorValidator.evRankMismatch,
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003693 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "logical_xor": {
3697 "op": Op.LOGICAL_XOR,
3698 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 "build_fcn": (
3700 build_binary_broadcast,
3701 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003702 TosaTensorValuesGen.tvgLazyGenDefault,
3703 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003706 "error_if_validators": (
3707 TosaErrorValidator.evRankMismatch,
3708 TosaErrorValidator.evWrongInputType,
3709 TosaErrorValidator.evWrongOutputType,
3710 TosaErrorValidator.evWrongInputList,
3711 TosaErrorValidator.evWrongOutputList,
3712 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003713 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003715 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "maximum": {
3717 "op": Op.MAXIMUM,
3718 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 "build_fcn": (
3720 build_binary_broadcast,
3721 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003722 TosaTensorValuesGen.tvgLazyGenDefault,
3723 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003726 "error_if_validators": (
3727 TosaErrorValidator.evRankMismatch,
3728 TosaErrorValidator.evWrongInputType,
3729 TosaErrorValidator.evWrongOutputType,
3730 TosaErrorValidator.evWrongInputList,
3731 TosaErrorValidator.evWrongOutputList,
3732 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003733 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003735 "data_gen": {
3736 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003739 "minimum": {
3740 "op": Op.MINIMUM,
3741 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 "build_fcn": (
3743 build_binary_broadcast,
3744 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003745 TosaTensorValuesGen.tvgLazyGenDefault,
3746 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 "error_if_validators": (
3750 TosaErrorValidator.evRankMismatch,
3751 TosaErrorValidator.evWrongInputType,
3752 TosaErrorValidator.evWrongOutputType,
3753 TosaErrorValidator.evWrongInputList,
3754 TosaErrorValidator.evWrongOutputList,
3755 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003756 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003757 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003758 "data_gen": {
3759 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3760 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "mul": {
3763 "op": Op.MUL,
3764 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003765 "build_fcn": (
3766 build_mul,
3767 TosaTensorGen.tgBroadcastFuzz,
3768 TosaTensorValuesGen.tvgMul,
3769 TosaArgGen.agMul,
3770 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003772 "error_if_validators": (
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 TosaErrorValidator.evRankMismatch,
3778 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003779 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003780 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003781 "data_gen": {
3782 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3783 },
3784 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "pow": {
3787 "op": Op.POW,
3788 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_binary_broadcast,
3791 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003792 TosaTensorValuesGen.tvgPow,
3793 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003796 "error_if_validators": (
3797 TosaErrorValidator.evRankMismatch,
3798 TosaErrorValidator.evWrongInputType,
3799 TosaErrorValidator.evWrongOutputType,
3800 TosaErrorValidator.evWrongInputList,
3801 TosaErrorValidator.evWrongOutputList,
3802 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003803 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003805 "data_gen": {
3806 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 "sub": {
3810 "op": Op.SUB,
3811 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 "build_fcn": (
3813 build_binary_broadcast,
3814 TosaTensorGen.tgBroadcastFuzz,
3815 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003816 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evRankMismatch,
3821 TosaErrorValidator.evWrongInputType,
3822 TosaErrorValidator.evWrongOutputType,
3823 TosaErrorValidator.evWrongInputList,
3824 TosaErrorValidator.evWrongOutputList,
3825 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003826 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003828 "data_gen": {
3829 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3830 },
3831 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "table": {
3834 "op": Op.TABLE,
3835 # Use the automatic generation functions to create the input array
3836 # but create the table tensor in the build function, as it may be
3837 # a different type from the input
3838 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003839 "build_fcn": (
3840 build_table,
3841 TosaTensorGen.tgBasic,
3842 TosaTensorValuesGen.tvgDefault,
3843 TosaArgGen.agTable,
3844 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003845 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003846 "error_if_validators": (
3847 TosaErrorValidator.evWrongInputType,
3848 TosaErrorValidator.evWrongOutputType,
3849 TosaErrorValidator.evWrongInputList,
3850 TosaErrorValidator.evWrongOutputList,
3851 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 # Elementwise Unary operators
3854 "abs": {
3855 "op": Op.ABS,
3856 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003857 "build_fcn": (
3858 build_unary,
3859 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003860 TosaTensorValuesGen.tvgLazyGenDefault,
3861 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003862 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003864 "error_if_validators": (
3865 TosaErrorValidator.evWrongInputType,
3866 TosaErrorValidator.evWrongOutputType,
3867 TosaErrorValidator.evWrongInputList,
3868 TosaErrorValidator.evWrongOutputList,
3869 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003870 "data_gen": {
3871 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3872 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 "bitwise_not": {
3875 "op": Op.BITWISE_NOT,
3876 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003877 "build_fcn": (
3878 build_unary,
3879 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 TosaTensorValuesGen.tvgLazyGenDefault,
3881 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 "error_if_validators": (
3885 TosaErrorValidator.evWrongInputType,
3886 TosaErrorValidator.evWrongOutputType,
3887 TosaErrorValidator.evWrongInputList,
3888 TosaErrorValidator.evWrongOutputList,
3889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 "ceil": {
3892 "op": Op.CEIL,
3893 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_unary,
3896 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003897 TosaTensorValuesGen.tvgLazyGenDefault,
3898 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evWrongInputType,
3903 TosaErrorValidator.evWrongOutputType,
3904 TosaErrorValidator.evWrongInputList,
3905 TosaErrorValidator.evWrongOutputList,
3906 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003907 "data_gen": {
3908 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3909 },
3910 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003912 "clz": {
3913 "op": Op.CLZ,
3914 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003915 "build_fcn": (
3916 build_unary,
3917 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003918 TosaTensorValuesGen.tvgLazyGenDefault,
3919 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003920 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003922 "error_if_validators": (
3923 TosaErrorValidator.evWrongInputType,
3924 TosaErrorValidator.evWrongOutputType,
3925 TosaErrorValidator.evWrongInputList,
3926 TosaErrorValidator.evWrongOutputList,
3927 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003928 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003929 "exp": {
3930 "op": Op.EXP,
3931 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003932 "build_fcn": (
3933 build_unary,
3934 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003935 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003936 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003937 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003938 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 "error_if_validators": (
3940 TosaErrorValidator.evWrongInputType,
3941 TosaErrorValidator.evWrongOutputType,
3942 TosaErrorValidator.evWrongInputList,
3943 TosaErrorValidator.evWrongOutputList,
3944 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003945 "data_gen": {
3946 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3947 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003948 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003949 "floor": {
3950 "op": Op.FLOOR,
3951 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 "build_fcn": (
3953 build_unary,
3954 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003955 TosaTensorValuesGen.tvgLazyGenDefault,
3956 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003957 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 "error_if_validators": (
3960 TosaErrorValidator.evWrongInputType,
3961 TosaErrorValidator.evWrongOutputType,
3962 TosaErrorValidator.evWrongInputList,
3963 TosaErrorValidator.evWrongOutputList,
3964 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003965 "data_gen": {
3966 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3967 },
3968 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 "log": {
3971 "op": Op.LOG,
3972 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_unary,
3975 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003976 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003977 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 "error_if_validators": (
3981 TosaErrorValidator.evWrongInputType,
3982 TosaErrorValidator.evWrongOutputType,
3983 TosaErrorValidator.evWrongInputList,
3984 TosaErrorValidator.evWrongOutputList,
3985 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003986 "data_gen": {
3987 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3988 },
3989 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "logical_not": {
3992 "op": Op.LOGICAL_NOT,
3993 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_unary,
3996 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003997 TosaTensorValuesGen.tvgLazyGenDefault,
3998 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evWrongInputType,
4003 TosaErrorValidator.evWrongOutputType,
4004 TosaErrorValidator.evWrongInputList,
4005 TosaErrorValidator.evWrongOutputList,
4006 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004007 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004008 "negate": {
4009 "op": Op.NEGATE,
4010 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004011 "build_fcn": (
4012 build_unary,
4013 TosaTensorGen.tgBasic,
4014 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004015 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "qgen": TosaQuantGen.qgUnary,
4018 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004019 "error_if_validators": (
4020 TosaErrorValidator.evInputZeroPointNotZero,
4021 TosaErrorValidator.evOutputZeroPointNotZero,
4022 TosaErrorValidator.evWrongInputType,
4023 TosaErrorValidator.evWrongOutputType,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004027 "data_gen": {
4028 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 "reciprocal": {
4032 "op": Op.RECIPROCAL,
4033 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 "build_fcn": (
4035 build_unary,
4036 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004037 TosaTensorValuesGen.tvgLazyGenDefault,
4038 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004040 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 "error_if_validators": (
4042 TosaErrorValidator.evWrongInputType,
4043 TosaErrorValidator.evWrongOutputType,
4044 TosaErrorValidator.evWrongInputList,
4045 TosaErrorValidator.evWrongOutputList,
4046 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004047 "data_gen": {
4048 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4049 },
4050 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004052 "rsqrt": {
4053 "op": Op.RSQRT,
4054 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004055 "build_fcn": (
4056 build_unary,
4057 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004058 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004059 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004062 "error_if_validators": (
4063 TosaErrorValidator.evWrongInputType,
4064 TosaErrorValidator.evWrongOutputType,
4065 TosaErrorValidator.evWrongInputList,
4066 TosaErrorValidator.evWrongOutputList,
4067 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004068 "data_gen": {
4069 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4070 },
4071 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004073 # Elementwise Ternary operators
4074 "select": {
4075 "op": Op.SELECT,
4076 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004077 "build_fcn": (
4078 build_select,
4079 TosaTensorGen.tgBroadcastFuzz,
4080 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004081 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004084 "error_if_validators": (
4085 TosaErrorValidator.evRankMismatch,
4086 TosaErrorValidator.evWrongInputType,
4087 TosaErrorValidator.evWrongOutputType,
4088 TosaErrorValidator.evWrongInputList,
4089 TosaErrorValidator.evWrongOutputList,
4090 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004091 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004092 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004093 "data_gen": {
4094 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004097 # Comparison operators
4098 "equal": {
4099 "op": Op.EQUAL,
4100 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004101 "build_fcn": (
4102 build_comparison,
4103 TosaTensorGen.tgBroadcastFuzz,
4104 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004105 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004106 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004108 "error_if_validators": (
4109 TosaErrorValidator.evRankMismatch,
4110 TosaErrorValidator.evWrongInputType,
4111 TosaErrorValidator.evWrongOutputType,
4112 TosaErrorValidator.evWrongInputList,
4113 TosaErrorValidator.evWrongOutputList,
4114 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004115 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004116 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004117 "data_gen": {
4118 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004121 "greater_equal": {
4122 "op": Op.GREATER_EQUAL,
4123 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004124 "build_fcn": (
4125 build_comparison,
4126 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004127 TosaTensorValuesGen.tvgLazyGenDefault,
4128 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004129 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004130 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004131 "error_if_validators": (
4132 TosaErrorValidator.evRankMismatch,
4133 TosaErrorValidator.evWrongInputType,
4134 TosaErrorValidator.evWrongOutputType,
4135 TosaErrorValidator.evWrongInputList,
4136 TosaErrorValidator.evWrongOutputList,
4137 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004138 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004139 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004140 "data_gen": {
4141 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 "greater": {
4145 "op": Op.GREATER,
4146 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004147 "build_fcn": (
4148 build_comparison,
4149 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004150 TosaTensorValuesGen.tvgLazyGenDefault,
4151 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004152 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004153 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004154 "error_if_validators": (
4155 TosaErrorValidator.evRankMismatch,
4156 TosaErrorValidator.evWrongInputType,
4157 TosaErrorValidator.evWrongOutputType,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
4160 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004161 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004163 "data_gen": {
4164 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004166 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004167 # Reduction operators
4168 "reduce_all": {
4169 "op": Op.REDUCE_ALL,
4170 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004171 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004172 "build_fcn": (
4173 build_reduce,
4174 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004175 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004176 TosaArgGen.agAxis,
4177 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004179 "error_if_validators": (
4180 TosaErrorValidator.evAxisLargerRank,
4181 TosaErrorValidator.evAxisSmallerZero,
4182 TosaErrorValidator.evShapeOfAxisNotOne,
4183 TosaErrorValidator.evWrongInputType,
4184 TosaErrorValidator.evWrongOutputType,
4185 TosaErrorValidator.evWrongRank,
4186 TosaErrorValidator.evWrongInputList,
4187 TosaErrorValidator.evWrongOutputList,
4188 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 "reduce_any": {
4191 "op": Op.REDUCE_ANY,
4192 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004193 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_reduce,
4196 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004197 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004198 TosaArgGen.agAxis,
4199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evAxisLargerRank,
4203 TosaErrorValidator.evAxisSmallerZero,
4204 TosaErrorValidator.evShapeOfAxisNotOne,
4205 TosaErrorValidator.evWrongInputType,
4206 TosaErrorValidator.evWrongOutputType,
4207 TosaErrorValidator.evWrongRank,
4208 TosaErrorValidator.evWrongInputList,
4209 TosaErrorValidator.evWrongOutputList,
4210 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004212 "reduce_max": {
4213 "op": Op.REDUCE_MAX,
4214 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004215 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004216 "build_fcn": (
4217 build_reduce,
4218 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004219 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 TosaArgGen.agAxis,
4221 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004222 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004223 "error_if_validators": (
4224 TosaErrorValidator.evAxisLargerRank,
4225 TosaErrorValidator.evAxisSmallerZero,
4226 TosaErrorValidator.evShapeOfAxisNotOne,
4227 TosaErrorValidator.evWrongInputType,
4228 TosaErrorValidator.evWrongOutputType,
4229 TosaErrorValidator.evWrongRank,
4230 TosaErrorValidator.evWrongInputList,
4231 TosaErrorValidator.evWrongOutputList,
4232 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004233 "data_gen": {
4234 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004236 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004237 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004238 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004239 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004240 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 "build_fcn": (
4242 build_reduce,
4243 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004244 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004245 TosaArgGen.agAxis,
4246 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evAxisLargerRank,
4250 TosaErrorValidator.evAxisSmallerZero,
4251 TosaErrorValidator.evShapeOfAxisNotOne,
4252 TosaErrorValidator.evWrongInputType,
4253 TosaErrorValidator.evWrongOutputType,
4254 TosaErrorValidator.evWrongRank,
4255 TosaErrorValidator.evWrongInputList,
4256 TosaErrorValidator.evWrongOutputList,
4257 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004258 "data_gen": {
4259 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4260 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004261 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004262 "reduce_product": {
4263 "op": Op.REDUCE_PRODUCT,
4264 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004265 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004266 "build_fcn": (
4267 build_reduce,
4268 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004269 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004270 TosaArgGen.agAxis,
4271 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004272 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004273 "error_if_validators": (
4274 TosaErrorValidator.evAxisLargerRank,
4275 TosaErrorValidator.evAxisSmallerZero,
4276 TosaErrorValidator.evShapeOfAxisNotOne,
4277 TosaErrorValidator.evWrongInputType,
4278 TosaErrorValidator.evWrongOutputType,
4279 TosaErrorValidator.evWrongRank,
4280 TosaErrorValidator.evWrongInputList,
4281 TosaErrorValidator.evWrongOutputList,
4282 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004283 "data_gen": {
4284 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004287 "reduce_sum": {
4288 "op": Op.REDUCE_SUM,
4289 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004290 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004291 "build_fcn": (
4292 build_reduce,
4293 TosaTensorGen.tgBasic,
4294 TosaTensorValuesGen.tvgReduceSum,
4295 TosaArgGen.agAxis,
4296 ),
James Ward24dbc422022-10-19 12:20:31 +01004297 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004298 "error_if_validators": (
4299 TosaErrorValidator.evAxisLargerRank,
4300 TosaErrorValidator.evAxisSmallerZero,
4301 TosaErrorValidator.evShapeOfAxisNotOne,
4302 TosaErrorValidator.evWrongInputType,
4303 TosaErrorValidator.evWrongOutputType,
4304 TosaErrorValidator.evWrongRank,
4305 TosaErrorValidator.evWrongInputList,
4306 TosaErrorValidator.evWrongOutputList,
4307 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004308 "data_gen": {
4309 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004311 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004312 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004313 "concat": {
4314 "op": Op.CONCAT,
4315 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 "build_fcn": (
4317 build_concat,
4318 TosaTensorGen.tgConcat,
4319 TosaTensorValuesGen.tvgConcat,
4320 TosaArgGen.agAxis,
4321 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004322 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004323 "error_if_validators": (
4324 TosaErrorValidator.evAxisLargerRank,
4325 TosaErrorValidator.evAxisSmallerZero,
4326 TosaErrorValidator.evConcatInputRankMismatch,
4327 TosaErrorValidator.evConcatShapeSumMismatch,
4328 TosaErrorValidator.evConcatInputDimMismatch,
4329 TosaErrorValidator.evWrongInputType,
4330 TosaErrorValidator.evWrongOutputType,
4331 TosaErrorValidator.evWrongOutputList,
4332 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004333 "data_gen": {
4334 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4335 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004336 },
4337 "pad": {
4338 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004339 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004340 "build_fcn": (
4341 build_pad,
4342 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004343 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004344 TosaArgGen.agPad,
4345 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004346 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004347 "error_if_validators": (
4348 TosaErrorValidator.evWrongInputType,
4349 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004350 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004351 TosaErrorValidator.evWrongOutputType,
4352 TosaErrorValidator.evWrongInputList,
4353 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004354 TosaErrorValidator.evRankMismatch,
4355 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004356 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004357 "data_gen": {
4358 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4359 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004360 },
Won Jeona21b2e82023-08-10 10:33:01 +00004361 "dim": {
4362 "op": Op.DIM,
4363 "operands": (1, 0),
4364 "build_fcn": (
4365 build_dim,
4366 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004367 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004368 TosaArgGen.agAxis,
4369 ),
4370 "types": TYPE_FIB,
4371 "error_if_validators": (
4372 TosaErrorValidator.evAxisLargerRank,
4373 TosaErrorValidator.evAxisSmallerZero,
4374 TosaErrorValidator.evWrongInputType,
4375 TosaErrorValidator.evWrongInputList,
4376 TosaErrorValidator.evWrongOutputList,
4377 TosaErrorValidator.evWrongRank,
4378 ),
4379 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004380 "reshape": {
4381 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004382 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004383 "build_fcn": (
4384 build_reshape,
4385 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004386 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004387 TosaArgGen.agReshape,
4388 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004389 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004390 "error_if_validators": (
4391 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4392 TosaErrorValidator.evWrongInputType,
4393 TosaErrorValidator.evWrongOutputType,
4394 TosaErrorValidator.evWrongInputList,
4395 TosaErrorValidator.evWrongOutputList,
4396 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004397 "data_gen": {
4398 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4399 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004400 },
4401 "reverse": {
4402 "op": Op.REVERSE,
4403 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004404 "build_fcn": (
4405 build_reverse,
4406 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004407 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004408 TosaArgGen.agAxis,
4409 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004410 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004411 "error_if_validators": (
4412 TosaErrorValidator.evAxisSmallerZero,
4413 TosaErrorValidator.evAxisLargerRank,
4414 TosaErrorValidator.evWrongInputType,
4415 TosaErrorValidator.evWrongOutputType,
4416 TosaErrorValidator.evWrongInputList,
4417 TosaErrorValidator.evWrongOutputList,
4418 ),
evacha0198477222024-01-26 12:25:32 +00004419 "data_gen": {
4420 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4421 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004422 },
4423 "slice": {
4424 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004425 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004426 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004427 "build_fcn": (
4428 build_slice,
4429 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004430 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004431 TosaArgGen.agSlice,
4432 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004433 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004434 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004435 # TODO Turn off these error categories for now as the reference
4436 # model cannot allocate memory space for empty tensor. We probably
4437 # can report an accurate error message at the right place during
4438 # execution.
4439 # TosaErrorValidator.evStartSmallerZero,
4440 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004441 TosaErrorValidator.evStartSizeOutsideBounds,
4442 TosaErrorValidator.evSizeOutputShapeMismatch,
4443 TosaErrorValidator.evInputSizeStartLengthMismatch,
4444 TosaErrorValidator.evWrongRank,
4445 TosaErrorValidator.evWrongInputType,
4446 TosaErrorValidator.evWrongOutputType,
4447 TosaErrorValidator.evWrongInputList,
4448 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004449 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004450 ),
evacha017f7d4252024-01-24 12:08:09 +00004451 "data_gen": {
4452 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4453 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004454 },
4455 "tile": {
4456 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004457 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004458 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004459 "build_fcn": (
4460 build_tile,
4461 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004462 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004463 TosaArgGen.agTile,
4464 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004466 "error_if_validators": (
4467 TosaErrorValidator.evWrongInputType,
4468 TosaErrorValidator.evWrongOutputType,
4469 TosaErrorValidator.evWrongInputList,
4470 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004471 TosaErrorValidator.evRankMismatch,
4472 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004473 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004474 "data_gen": {
4475 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4476 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004477 },
4478 "transpose": {
4479 "op": Op.TRANSPOSE,
4480 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004481 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004482 "build_fcn": (
4483 build_transpose,
4484 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004485 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004486 TosaArgGen.agTranspose,
4487 ),
4488 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004489 "error_if_validators": (
4490 TosaErrorValidator.evIndexOutsideBounds,
4491 TosaErrorValidator.evIndexUsedTwice,
4492 TosaErrorValidator.evWrongInputType,
4493 TosaErrorValidator.evWrongOutputType,
4494 TosaErrorValidator.evWrongInputList,
4495 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004496 TosaErrorValidator.evWrongRank,
4497 TosaErrorValidator.evRankMismatch,
4498 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004499 ),
evacha0198477222024-01-26 12:25:32 +00004500 "data_gen": {
4501 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4502 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004504 # Data nodes
4505 "const": {
4506 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004507 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004508 "build_fcn": (
4509 build_const,
4510 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004511 TosaTensorValuesGen.tvgLazyGenDefault,
4512 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004513 ),
Luke Hutton65872422023-02-20 10:33:04 +00004514 "types": TYPE_FIB + [DType.INT48],
evacha0198477222024-01-26 12:25:32 +00004515 "data_gen": {
4516 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4517 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004519 "identity": {
4520 "op": Op.IDENTITY,
4521 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004522 "build_fcn": (
4523 build_unary,
4524 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004525 TosaTensorValuesGen.tvgLazyGenDefault,
4526 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004527 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004528 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004529 "data_gen": {
4530 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4531 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004532 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004533 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004534 "gather": {
4535 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004536 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004537 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004538 "build_fcn": (
4539 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004540 TosaTensorGen.tgGather,
4541 TosaTensorValuesGen.tvgGather,
4542 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004543 ),
James Ward24dbc422022-10-19 12:20:31 +01004544 "types": (
4545 DType.INT8,
4546 DType.INT16,
4547 DType.INT32,
4548 DType.FP16,
4549 DType.BF16,
4550 DType.FP32,
4551 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004552 "error_if_validators": (
4553 TosaErrorValidator.evWrongInputType,
4554 TosaErrorValidator.evWrongOutputType,
4555 TosaErrorValidator.evWrongInputList,
4556 TosaErrorValidator.evWrongOutputList,
4557 TosaErrorValidator.evWrongRank,
4558 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004559 "data_gen": {
4560 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4561 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004562 },
4563 "scatter": {
4564 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004565 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004566 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004567 "build_fcn": (
4568 build_scatter,
4569 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004570 TosaTensorValuesGen.tvgScatter,
4571 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004572 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004573 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004574 "error_if_validators": (
4575 TosaErrorValidator.evWrongInputType,
4576 TosaErrorValidator.evWrongOutputType,
4577 TosaErrorValidator.evWrongInputList,
4578 TosaErrorValidator.evWrongOutputList,
4579 TosaErrorValidator.evWrongRank,
4580 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004581 "data_gen": {
4582 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4583 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004584 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004585 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004586 "resize": {
4587 "op": Op.RESIZE,
4588 "operands": (1, 0),
4589 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004590 "build_fcn": (
4591 build_resize,
4592 TosaTensorGen.tgNHWC,
4593 TosaTensorValuesGen.tvgDefault,
4594 TosaArgGen.agResize,
4595 ),
James Ward24dbc422022-10-19 12:20:31 +01004596 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004597 "invalid_test_validators": (
4598 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004599 ),
4600 "error_if_validators": (
4601 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004602 TosaErrorValidator.evScaleSmallerEqualZero,
4603 TosaErrorValidator.evScaleNLargerMax,
4604 TosaErrorValidator.evScaleDLargerMax,
4605 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004606 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004607 TosaErrorValidator.evBorderSmallerMin,
4608 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004609 TosaErrorValidator.evWrongInputType,
4610 TosaErrorValidator.evWrongOutputType,
4611 TosaErrorValidator.evWrongRank,
4612 TosaErrorValidator.evWrongInputList,
4613 TosaErrorValidator.evWrongOutputList,
4614 TosaErrorValidator.evBatchMismatch,
4615 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004616 TosaErrorValidator.evResizeOutputShapeMismatch,
4617 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004618 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004619 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004620 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004621 "cast": {
4622 "op": Op.CAST,
4623 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004624 "build_fcn": (
4625 build_cast,
4626 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004627 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004628 TosaArgGen.agCast,
4629 ),
James Ward8b390432022-08-12 20:48:56 +01004630 "types": (
4631 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004632 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004633 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004634 DType.INT8,
4635 DType.INT16,
4636 DType.INT32,
4637 DType.BOOL,
4638 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004639 "error_if_validators": (
4640 TosaErrorValidator.evWrongInputType,
4641 TosaErrorValidator.evWrongOutputType,
4642 TosaErrorValidator.evWrongInputList,
4643 TosaErrorValidator.evWrongOutputList,
4644 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004645 "data_gen": {
4646 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4647 },
4648 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004649 },
4650 "rescale": {
4651 "op": Op.RESCALE,
4652 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004653 "build_fcn": (
4654 build_rescale,
4655 TosaTensorGen.tgBasic,
4656 TosaTensorValuesGen.tvgDefault,
4657 TosaArgGen.agRescale,
4658 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004659 "types": [
4660 DType.UINT8,
4661 DType.INT8,
4662 DType.INT16,
4663 DType.INT32,
4664 DType.INT48,
4665 DType.UINT16,
4666 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004667 "error_if_validators": (
4668 TosaErrorValidator.evInputZeroPointNotZero,
4669 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004670 TosaErrorValidator.evU16InputZeroPointNotValid,
4671 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004672 TosaErrorValidator.evScaleTrue,
4673 TosaErrorValidator.evScaleNotTrue,
4674 TosaErrorValidator.evWrongInputType,
4675 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004676 TosaErrorValidator.evWrongInputList,
4677 TosaErrorValidator.evWrongOutputList,
4678 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004679 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004680 # Custom
4681 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004682 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004683 # Two varients of cond_if, one that generates one of two constant tensors (no
4684 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4685 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004686 "cond_if_const": {
4687 "op": Op.COND_IF,
4688 "operands": (0, 2),
4689 "build_fcn": (
4690 build_cond_if_const,
4691 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004692 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004693 TosaArgGen.agCondIf,
4694 ),
4695 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004696 "error_if_validators": (
4697 TosaErrorValidator.evOutputListThenGraphMismatch,
4698 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004699 TosaErrorValidator.evCondIfCondNotMatchingBool,
4700 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004701 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004702 },
4703 "cond_if_binary": {
4704 "op": Op.COND_IF,
4705 "operands": (2, 0),
4706 "build_fcn": (
4707 build_cond_if_binary,
4708 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004709 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004710 TosaArgGen.agCondIf,
4711 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004712 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004713 "error_if_validators": (
4714 TosaErrorValidator.evInputListThenGraphMismatch,
4715 TosaErrorValidator.evInputListElseGraphMismatch,
4716 TosaErrorValidator.evOutputListThenGraphMismatch,
4717 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004718 TosaErrorValidator.evCondIfCondNotMatchingBool,
4719 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004720 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004721 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004722 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004723 "while_loop": {
4724 "op": Op.WHILE_LOOP,
4725 "operands": (0, 1),
4726 "build_fcn": (
4727 build_while_loop,
4728 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004729 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004730 TosaArgGen.agWhileLoop,
4731 ),
4732 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004733 "error_if_validators": (
4734 TosaErrorValidator.evInputListOutputListMismatch,
4735 TosaErrorValidator.evInputListCondGraphMismatch,
4736 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4737 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4738 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004739 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004741 },
Luke Hutton57287132023-02-06 14:54:18 +00004742 "fft2d": {
4743 "op": Op.FFT2D,
4744 "operands": (2, 0),
4745 "rank": (3, 3),
4746 "build_fcn": (
4747 build_fft2d,
4748 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004749 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004750 TosaArgGen.agFFT2d,
4751 ),
4752 "types": [DType.FP32],
4753 "error_if_validators": (
4754 TosaErrorValidator.evWrongInputType,
4755 TosaErrorValidator.evWrongOutputType,
4756 TosaErrorValidator.evWrongInputList,
4757 TosaErrorValidator.evWrongOutputList,
4758 TosaErrorValidator.evWrongRank,
4759 TosaErrorValidator.evBatchMismatch,
4760 TosaErrorValidator.evKernelNotPowerOfTwo,
4761 TosaErrorValidator.evFFTInputShapeMismatch,
4762 TosaErrorValidator.evFFTOutputShapeMismatch,
4763 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004764 "data_gen": {
4765 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4766 },
Luke Hutton57287132023-02-06 14:54:18 +00004767 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004768 "rfft2d": {
4769 "op": Op.RFFT2D,
4770 "operands": (1, 0),
4771 "rank": (3, 3),
4772 "build_fcn": (
4773 build_rfft2d,
4774 TosaTensorGen.tgRFFT2d,
4775 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004776 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004777 ),
4778 "types": [DType.FP32],
4779 "error_if_validators": (
4780 TosaErrorValidator.evWrongInputType,
4781 TosaErrorValidator.evWrongOutputType,
4782 TosaErrorValidator.evWrongInputList,
4783 TosaErrorValidator.evWrongOutputList,
4784 TosaErrorValidator.evWrongRank,
4785 TosaErrorValidator.evBatchMismatch,
4786 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004787 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004788 ),
4789 },
Won Jeon74342e52024-01-09 00:34:40 +00004790 # Shape
4791 "add_shape": {
4792 "op": Op.ADD_SHAPE,
4793 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004794 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004795 "build_fcn": (
4796 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004797 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004798 TosaTensorValuesGen.tvgAddSub,
4799 TosaArgGen.agNone,
4800 ),
4801 "types": [DType.SHAPE],
4802 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4803 },
4804 "sub_shape": {
4805 "op": Op.SUB_SHAPE,
4806 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004807 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004808 "build_fcn": (
4809 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004810 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004811 TosaTensorValuesGen.tvgAddSub,
4812 TosaArgGen.agNone,
4813 ),
4814 "types": [DType.SHAPE],
4815 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4816 },
4817 "mul_shape": {
4818 "op": Op.MUL_SHAPE,
4819 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004820 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004821 "build_fcn": (
4822 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004823 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004824 TosaTensorValuesGen.tvgMul,
4825 TosaArgGen.agNone,
4826 ),
4827 "types": [DType.SHAPE],
4828 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4829 },
4830 "div_shape": {
4831 "op": Op.DIV_SHAPE,
4832 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004833 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004834 "build_fcn": (
4835 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004836 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004837 TosaTensorValuesGen.tvgIntDiv,
4838 TosaArgGen.agNone,
4839 ),
4840 "types": [DType.SHAPE],
4841 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4842 },
4843 "concat_shape": {
4844 "op": Op.CONCAT_SHAPE,
4845 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004846 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004847 "build_fcn": (
4848 build_concat,
4849 TosaTensorGen.tgConcat,
4850 TosaTensorValuesGen.tvgConcat,
4851 TosaArgGen.agNone,
4852 ),
4853 "types": [DType.SHAPE],
4854 "error_if_validators": (),
4855 },
4856 "const_shape": {
4857 "op": Op.CONST_SHAPE,
4858 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004859 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004860 "build_fcn": (
4861 build_const,
4862 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004863 TosaTensorValuesGen.tvgLazyGenDefault,
4864 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004865 ),
4866 "types": [DType.SHAPE],
4867 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004868 }
4869
Kevin Cheng550ccc52021-03-03 11:21:43 -08004870
Eric Kunzee5e26762020-10-13 16:11:07 -07004871class OutputShaper:
4872 # Methods in this class compute the expected output shape and datatype
4873 # for common classes of operations
4874 def __init__(self):
4875 pass
4876
    # The static methods below compute the expected output shape and dtype,
    # create the output tensor via ser.addOutput() and return that result
4879 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004880 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4881 if error_name != ErrorIf.RankMismatch:
4882 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004883 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004884
4885 shape = []
4886 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004887 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004888 shape.append(b.shape[i])
4889 else:
4890 shape.append(a.shape[i])
4891
Jerry Ge135c9552023-05-23 20:59:32 +00004892 fuzz_idx = rng.integers(0, len(a.shape))
4893 if error_name == ErrorIf.DimensionMismatch:
4894 shape[fuzz_idx] += 1
4895
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004896 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004897 all_dtypes = [
4898 DType.INT8,
4899 DType.INT16,
4900 DType.INT32,
4901 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004902 DType.FP16,
4903 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004904 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004905 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004906 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4907 outputDType = rng.choice(wrong_dtypes)
4908 else:
4909 outputDType = a.dtype
4910
4911 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004912
4913 @staticmethod
4914 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004915 assert len(a.shape) == len(b.shape)
4916 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004917
4918 shape = []
4919 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004920 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004921 shape.append(a.shape[i])
4922
Kevin Cheng550ccc52021-03-03 11:21:43 -08004923 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004924
4925 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004926 def unaryOp(ser, rng, a, error_name=None):
4927 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004928 all_dtypes = [
4929 DType.INT8,
4930 DType.INT16,
4931 DType.INT32,
4932 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004933 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004934 DType.FP16,
4935 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004936 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004937 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4938 outputDType = rng.choice(wrong_dtypes)
4939 else:
4940 outputDType = a.dtype
4941
4942 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004943
4944 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004945 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004946 if error_name != ErrorIf.RankMismatch:
4947 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004948 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004949
4950 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004951 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004952 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004953 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4954 else:
4955 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004956
Jerry Ge135c9552023-05-23 20:59:32 +00004957 fuzz_idx = rng.integers(0, len(a.shape))
4958 if error_name == ErrorIf.DimensionMismatch:
4959 shape[fuzz_idx] += 1
4960
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004961 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004962 all_dtypes = [
4963 DType.INT8,
4964 DType.INT16,
4965 DType.INT32,
4966 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004967 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004968 DType.FP16,
4969 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004970 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004971 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4972 outputDType = rng.choice(wrong_dtypes)
4973 else:
4974 outputDType = a.dtype
4975
4976 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004977
4978 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004979 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004980 if error_name != ErrorIf.RankMismatch:
4981 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004982 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004983
4984 # Do broadcast
4985 shape = []
4986 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004987 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004988 shape.append(b.shape[i])
4989 else:
4990 shape.append(a.shape[i])
4991
Jerry Ge135c9552023-05-23 20:59:32 +00004992 fuzz_idx = rng.integers(0, len(a.shape))
4993 if error_name == ErrorIf.DimensionMismatch:
4994 shape[fuzz_idx] += 1
4995
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004996 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004997 wrong_dtypes = [
4998 DType.INT8,
4999 DType.INT16,
5000 DType.INT32,
5001 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005002 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005003 DType.FP16,
5004 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005005 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005006 outputDType = rng.choice(wrong_dtypes)
5007 else:
5008 outputDType = DType.BOOL
5009
5010 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005011
5012 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005013 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005014 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005015 if error_name not in [
5016 ErrorIf.AxisSmallerZero,
5017 ErrorIf.AxisLargerRank,
5018 ErrorIf.ShapeOfAxisNotOne,
5019 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005020 shape[axis] = 1
5021 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5022 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005023
Matthew Haddond6ce7252021-09-29 15:35:44 +01005024 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005025 all_dtypes = [
5026 DType.INT8,
5027 DType.INT16,
5028 DType.INT32,
5029 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005030 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005031 DType.FP16,
5032 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005033 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005034 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5035 outputDType = rng.choice(wrong_dtypes)
5036 else:
5037 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005038
Matthew Haddond6ce7252021-09-29 15:35:44 +01005039 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005040
5041 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005042 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005043 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005044
5045 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5046 del shape[axis]
5047
5048 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5049 remove = rng.choice([True, False])
5050 if remove and len(shape) > 1:
5051 del shape[0]
5052 else:
5053 shape.append(1)
5054 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5055 for i in range(len(shape)):
5056 shape[i] = shape[i] + rng.integers(1, 10)
5057
5058 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005059 all_dtypes = [
5060 DType.INT8,
5061 DType.INT16,
5062 DType.INT32,
5063 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005064 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005065 DType.FP16,
5066 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005067 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005068 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5069 outputDType = rng.choice(wrong_dtypes)
5070 else:
5071 outputDType = DType.INT32
5072
5073 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005074
5075 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005076 def conv2dOp(
5077 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5078 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005079
5080 # IFM: NHWC
5081 # Filter: OHWI
5082 # OFM: NHWC
5083
Kevin Cheng550ccc52021-03-03 11:21:43 -08005084 h = (
5085 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005086 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005087 + padding[0]
5088 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005089 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005090 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005091
Kevin Cheng550ccc52021-03-03 11:21:43 -08005092 w = (
5093 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005094 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005095 + padding[2]
5096 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005097 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005098 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005099
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005100 if error_name == ErrorIf.ConvOutputShapeMismatch:
5101 choices = [1, 2, 3]
5102 change = rng.choice(choices)
5103 # increment in multiples of stride to not hit non-integer error case
5104 if change in [1, 3]:
5105 h = h + (rng.choice(choices) * strides[0])
5106 if change in [2, 3]:
5107 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00005108
Eric Kunzee5e26762020-10-13 16:11:07 -07005109 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
5110
James Ward8b390432022-08-12 20:48:56 +01005111 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005112 # Pick some potentially correct output dtype if input type is incorrect
5113 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005114 else:
James Ward8b390432022-08-12 20:48:56 +01005115 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005116
5117 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005118 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005119 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005120 else:
5121 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005122 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005123 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005124
Kevin Cheng550ccc52021-03-03 11:21:43 -08005125 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005126
5127 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005128 def conv3dOp(
5129 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
5130 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07005131
5132 # IFM: NDHWC
5133 # Filter: ODHWI
5134 # OFM: NDHWC
5135
5136 d = (
5137 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005138 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005139 + padding[0]
5140 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005141 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07005142 ) // strides[0] + 1
5143
5144 h = (
5145 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005146 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005147 + padding[2]
5148 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005149 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07005150 ) // strides[1] + 1
5151
5152 w = (
5153 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005154 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07005155 + padding[4]
5156 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005157 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07005158 ) // strides[2] + 1
5159
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005160 if error_name == ErrorIf.ConvOutputShapeMismatch:
5161 choices = [1, 2, 3, 4]
5162 change = rng.choice(choices)
5163 # increment in multiples of stride to not hit non-integer error case
5164 if change in [1, 4]:
5165 d = d + (rng.choice(choices) * strides[0])
5166 if change in [2, 4]:
5167 h = h + (rng.choice(choices) * strides[1])
5168 if change in [3, 4]:
5169 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00005170
Kevin Cheng1533b852021-09-01 12:51:58 -07005171 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
5172
James Ward8b390432022-08-12 20:48:56 +01005173 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005174 # Pick some potentially correct output dtype if input type is incorrect
5175 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07005176 else:
James Ward8b390432022-08-12 20:48:56 +01005177 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005178
5179 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005180 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005181 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005182 else:
5183 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005184 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005185 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07005186
5187 return ser.addOutput(ofm_shape, out_dtype)
5188
5189 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005190 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01005191 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005192 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07005193 # IFM: NHWC
5194 # Filter: HWCM
5195 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005196
Kevin Cheng550ccc52021-03-03 11:21:43 -08005197 h = (
5198 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005199 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005200 + padding[0]
5201 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005202 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005203 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005204
Kevin Cheng550ccc52021-03-03 11:21:43 -08005205 w = (
5206 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005207 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08005208 + padding[2]
5209 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005210 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08005211 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07005212
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005213 if error_name == ErrorIf.ConvOutputShapeMismatch:
5214 choices = [1, 2, 3]
5215 change = rng.choice(choices)
5216 # increment in multiples of stride to not hit non-integer error case
5217 if change in [1, 3]:
5218 h = h + (rng.choice(choices) * strides[0])
5219 if change in [2, 3]:
5220 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00005221
Eric Kunzee5e26762020-10-13 16:11:07 -07005222 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
5223
James Ward8b390432022-08-12 20:48:56 +01005224 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005225 # Pick some potentially correct output dtype if input type is incorrect
5226 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005227 else:
James Ward8b390432022-08-12 20:48:56 +01005228 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005229
5230 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005231 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005232 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005233 else:
5234 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005235 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005236 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005237
Kevin Cheng550ccc52021-03-03 11:21:43 -08005238 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
5240 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005241 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005242 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005243 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005244 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005245 h = 1
5246 w = 1
5247 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005248 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
5249 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005250
5251 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005252 choices = [1, 2, 3]
5253 change = rng.choice(choices)
5254 # increment in multiples of stride to not hit non-integer error case
5255 if change in [1, 3]:
5256 h = h + (rng.choice(choices) * stride[0])
5257 if change in [2, 3]:
5258 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07005259 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005260
5261 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005262 all_dtypes = [
5263 DType.INT8,
5264 DType.INT16,
5265 DType.INT32,
5266 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005267 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01005268 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01005269 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005270 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005271 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
5272 outputDType = rng.choice(wrong_dtypes)
5273 else:
5274 outputDType = ifm.dtype
5275
5276 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005277
5278 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005279 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005280 # input: N, IC
5281 # filter: OC, IC
5282 # output: N, OC
5283
5284 output_shape = [input.shape[0], filter.shape[0]]
5285
James Ward8b390432022-08-12 20:48:56 +01005286 # Validated in arg_gen (also invalidated for ErrorIf)
5287 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005288
Kevin Cheng550ccc52021-03-03 11:21:43 -08005289 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005290
5291 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005292 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07005293 # a: N, H, C
5294 # b: N, C, W
5295 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07005296
Kevin Cheng2d60f002021-06-09 14:18:32 -07005297 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005298
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005299 if error_name == ErrorIf.WrongOutputType:
5300 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005301 incorrect_types = (
5302 DType.INT4,
5303 DType.INT8,
5304 DType.INT16,
5305 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005306 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005307 DType.FP16,
5308 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005309 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005310 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005311 incorrect_types = (
5312 DType.INT4,
5313 DType.INT8,
5314 DType.INT16,
5315 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005316 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005317 DType.FP16,
5318 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005319 )
James Ward24dbc422022-10-19 12:20:31 +01005320 elif (
5321 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
5322 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005323 incorrect_types = (
5324 DType.INT4,
5325 DType.INT8,
5326 DType.INT16,
5327 DType.INT32,
5328 DType.INT48,
5329 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005330 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005331 elif error_name == ErrorIf.WrongInputType:
5332 # Pick some potentially correct output dtype if input type is incorrect
5333 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005334 else:
James Ward8b390432022-08-12 20:48:56 +01005335 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07005336
Kevin Cheng550ccc52021-03-03 11:21:43 -08005337 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005338
5339 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00005340 def concatOp(ser, rng, axis, inputs, error_name=None):
5341 input1 = inputs[0]
5342 remaining_inputs = inputs[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005343
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005344 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01005345 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005346 if not (
5347 # unable to concat tensors of different ranks
5348 error_name == ErrorIf.ConcatInputRankMismatch
5349 # unable to concat tensors along an invalid axis
5350 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005351 ):
5352 for tensor in remaining_inputs:
5353 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07005354
Matthew Haddon01c359d2021-10-15 16:30:48 +01005355 if error_name == ErrorIf.ConcatShapeSumMismatch:
5356 output_shape[axis] += rng.integers(5, 10)
5357
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005358 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005359 all_dtypes = {
5360 DType.INT8,
5361 DType.INT16,
5362 DType.INT32,
5363 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005364 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005365 DType.FP16,
5366 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005367 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005368 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
5369 outputDType = rng.choice(wrong_dtypes)
5370 else:
5371 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01005372
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005373 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005374
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register and return the output tensor for a PAD operator.

    Output dimension ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ERROR_IF cases handled here corrupt the result on purpose:

    * ``PadOutputShapeMismatch``: one random dimension grows by 1 or 2.
    * ``RankMismatch``: the shape is replaced by
      ``gtu.get_rank_mismatch_shape``.
    * ``WrongOutputType``: a random dtype other than ``a.dtype`` is used.

    Args:
        ser: Serializer; the output is registered through ``ser.addOutput``.
        rng: Random generator used by the error-case perturbations.
        a: Input tensor (``shape`` and ``dtype`` are read).
        padding: Per-dimension ``(before, after)`` padding amounts.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = [
        padding[idx][0] + padding[idx][1] + dim for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005405
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Register and return the single-element output of a DIM operator.

    The output shape is always ``[1]`` and the dtype is ``DType.SHAPE``.
    ``a`` and ``axis`` are not read here. For the ``WrongOutputType``
    ERROR_IF case, a random dtype from a fixed non-SHAPE list is used instead.

    Args:
        ser: Serializer; the output is registered through ``ser.addOutput``.
        rng: Random generator used for the wrong-dtype pick.
        a: Input tensor (unused).
        axis: Queried axis (unused).
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = [1]
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, DType.SHAPE)

    # DType.SHAPE is not in this list, so every candidate is a wrong type
    non_shape_dtypes = list(
        set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
    )
    return ser.addOutput(output_shape, rng.choice(non_shape_dtypes))
5426
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register and return the output tensor for a RESHAPE operator.

    The output takes a copy of the requested ``shape`` and the dtype of ``a``.
    For ``TensorSizeInputOutputMismatch``, every dimension is increased by a
    random amount in [1, 9]. For ``WrongOutputType``, a random dtype other
    than ``a.dtype`` is used.

    Args:
        ser: Serializer; the output is registered through ``ser.addOutput``.
        rng: Random generator used by the error-case perturbations.
        a: Input tensor (only ``dtype`` is read).
        shape: Requested output shape (copied, not modified).
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005451
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register and return the output tensor for a SLICE operator.

    The output shape is a copy of ``size`` and the dtype is ``input.dtype``.
    ERROR_IF cases handled here:

    * ``WrongOutputType``: a random dtype other than ``input.dtype``.
    * ``SizeOutputShapeMismatch``: each dimension is shifted by a random
      non-zero delta (only positive deltas when the dimension is <= 2).
    * ``InputSizeStartLengthMismatch``: the input shape is used instead.
    * ``RankMismatch``: shape from ``gtu.get_rank_mismatch_shape``.

    ``start`` is not read here.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # The dtype is drawn before any shape perturbation; both consume rng,
    # so this order determines the generated values.
    output_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {input.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dimensions <= 2 only grow, so a positive size stays positive
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005485
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register and return the output tensor for a TILE operator.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``. For
    ``RankMismatch``, the shape is replaced by
    ``gtu.get_rank_mismatch_shape``. For ``WrongOutputType``, a random dtype
    other than ``a.dtype`` is used.

    Args:
        ser: Serializer; the output is registered through ``ser.addOutput``.
        rng: Random generator used by the error-case perturbations.
        a: Input tensor (``shape`` and ``dtype`` are read).
        multiples: Per-dimension repeat counts; same length as ``a.shape``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * multiples[idx] for idx, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005514
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register and return the output tensor for a TRANSPOSE operator.

    Output dimension ``i`` is ``a.shape[perms[i]]``. Exception: for the
    ``IndexOutsideBounds`` and ``IndexUsedTwice`` cases, the input shape is
    kept unpermuted. Further ERROR_IF handling:

    * ``TensorSizeInputOutputMismatch``: each dimension grows by 1-9.
    * ``RankMismatch``: shape from ``gtu.get_rank_mismatch_shape``.
    * ``WrongOutputType``: a random dtype other than ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(perms) == len(a.shape)

    if error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        # perms is presumably not a valid permutation in these cases, so it
        # is not applied
        output_shape = a.shape.copy()
    else:
        output_shape = [a.shape[perm] for perm in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005547
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register and return the output tensor for a GATHER operator.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The dtype is
    ``values.dtype``, or a random other dtype for ``WrongOutputType``. The
    rank-3 check on ``values`` is skipped for the ``WrongRank`` case.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    # Dims 0 and 2 come from values, dim 1 from indices
    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    output_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {values.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005573
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register and return the output tensor for a SCATTER operator.

    The output reuses ``values_in.shape`` (the same object, not a copy). The
    dtype is ``values_in.dtype``, or a random other dtype for
    ``WrongOutputType``. The rank-3 check on ``values_in`` is skipped for the
    ``WrongRank`` case.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    output_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        output_dtype = rng.choice(list(candidates - {values_in.dtype}))

    return ser.addOutput(values_in.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005602
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register and return the output tensor for a TABLE operator.

    The shape is the same as the input. The dtype is INT32 for INT16 input
    and INT8 otherwise. For ``WrongOutputType``, a random different dtype
    from a fixed list is used. The input dtype check is skipped for the
    ``WrongInputType`` case.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        # The list order is significant: rng.choice picks by position
        output_dtype = rng.choice(
            [
                dtype
                for dtype in (
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != output_dtype
            ]
        )

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005622
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register and return the output tensor for a RESIZE operator.

    The output is ``[N, OH, OW, C]``, with each spatial extent computed as
    ``((I - 1) * scale_n - offset + border) // scale_d + 1``. Here
    ``scale = [y_n, y_d, x_n, x_d]``, ``offset = [y, x]`` and
    ``border = [y, x]``. ``mode`` and ``input_dtype`` are not read here.

    ERROR_IF handling:

    * ``ScaleSmallerEqualZero``: scale terms are clamped to >= 1 before use.
    * Any error case: OH/OW are clamped to >= 1. Except for
      ``MaxDimExceeded``, they are also clamped to
      ``gtu.MAX_RESIZE_DIMENSION - 1``.
    * ``ResizeOutputShapeMismatch``: OH and/or OW move by one scale
      denominator.
    * ``WrongRank`` / ``BatchMismatch`` / ``ChannelMismatch``: the
      dimensions are altered as described below.

    Returns:
        The value returned by ``serializer.addOutput``.
    """
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
        scale[0],
        scale[1],
        scale[2],
        scale[3],
    )
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the shape arithmetic below well-defined
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    out_h = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    out_w = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # scale/offset/border may have been changed for the ERROR_IF case;
        # keep the output tensor itself valid
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h = min(out_h, limit)
            out_w = min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def nudge(extent, step):
            # Move by a whole scale denominator so the non-integer error
            # case is not hit instead; step down if stepping up would exceed
            # the maximum dimension
            if extent + step >= gtu.MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
                return extent
            return extent + step

        change = rng.choice([1, 2, 3])
        # 1 -> height only, 2 -> width only, 3 -> both
        if change in (1, 3):
            out_h = nudge(out_h, scale_y_d)
        if change in (2, 3):
            out_w = nudge(out_w, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Dimension 3 is not read here (the input presumably is not rank 4 in
        # this case); the batch size fills the last slot instead
        output_dims = [batch, out_h, out_w, batch]
    else:
        channels = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)
        output_dims = [batch, out_h, out_w, channels]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005701
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register and return the output tensor of a type conversion.

    The output keeps ``val.shape`` and takes ``out_dtype``. ``rng`` and
    ``error_name`` are not read.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    converted_shape = val.shape
    return ser.addOutput(converted_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005705
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register and return the output tensor for a TRANSPOSE_CONV2D operator.

    The given ``output_shape`` is used with ``accum_dtype``. ERROR_IF cases:

    * ``ConvOutputShapeMismatch``: dimension 1 and/or 2 grows by 1-3.
      NOTE: this modifies ``output_shape`` in place, so the caller's
      sequence sees the change too.
    * ``WrongInputType``: INT32 is used as the output dtype.
    * ``WrongOutputType``: a random dtype from ``gtu.usableDTypes``,
      excluding FP16/FP32 for FP16 input, otherwise the expected dtype.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        change = rng.choice(deltas)
        # change selects the axes to grow: 1 -> dim 1, 2 -> dim 2, 3 -> both
        for axis, selectors in ((1, (1, 3)), (2, (2, 3))):
            if change in selectors:
                output_shape[axis] += rng.choice(deltas)

    # With an incorrect input type, pick some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005731
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register and return the two output tensors for an FFT2D operator.

    Both outputs take a copy of ``ifm1.shape`` and ``ifm1.dtype``; they share
    the same shape list object. ERROR_IF cases:

    * ``WrongOutputType``: a random usable dtype other than FP32.
    * ``BatchMismatch``: dimension 0 grows by 1-9.
    * ``FFTOutputShapeMismatch``: dimension 1 or 2 grows by 1-9.

    The shape-equality and rank-3 checks are skipped for
    ``FFTInputShapeMismatch`` and ``WrongRank`` respectively.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    # Two identical outputs (presumably the real and imaginary parts)
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5762
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register and return the two output tensors for an RFFT2D operator.

    The leading dimensions of ``value.shape`` are kept, and the last
    dimension ``W`` becomes ``W // 2 + 1``. The dtype is ``value.dtype``.
    ERROR_IF cases:

    * ``WrongOutputType``: a random usable dtype other than FP32.
    * ``BatchMismatch``: dimension 0 grows by 1-9.
    * ``FFTOutputShapeMismatch``: dimension 1 or 2 grows by 1-9.

    The rank-3 check is skipped for ``WrongRank``.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    input_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = list(input_shape[:-1]) + [input_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    # Both outputs share the same shape list and dtype
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005787
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register and return the output tensor for an ADD_SHAPE operator.

    The output shape is a list copy of ``a.shape`` and the dtype is
    ``DType.SHAPE``. ERROR_IF cases:

    * ``DimensionMismatch``: one random dimension is incremented by 1.
    * ``WrongOutputType``: a random dtype from
      ``gtu.get_wrong_output_type(a.dtype)``.

    The rank-equality check is skipped for ``RankMismatch``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # The index is drawn unconditionally, so rng advances even when unused
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    output_dtype = DType.SHAPE
    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    return ser.addOutput(shape, output_dtype)