blob: bfafd23567d026679c00372a964ba82e32e149bd [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file; the
# copyright year is taken from the date the generator module is loaded.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
79 for dtype in (DType.FP32, DType.FP16, DType.BF16):
80 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000198 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 return np.int64(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
201 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
202
203 if dtype == DType.FP16:
204 return np.float16(f_tensor)
205 else:
206 f32_tensor = np.float32(f_tensor)
207 if dtype == DType.BF16:
208 # Floor the last 16 bits of each f32 value
209 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
210 else:
211 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100213 # All other integer types
214 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
Kevin Cheng989cb052021-04-28 16:29:44 -0700216 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 placeholders = []
218
Kevin Cheng989cb052021-04-28 16:29:44 -0700219 assert len(shape_list) == len(dtype_list)
220
Jeremy Johnson1271c442023-09-05 11:39:26 +0100221 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700222 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 if not self.args.lazy_data_gen:
224 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700225 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700226
227 return placeholders
228
Kevin Cheng989cb052021-04-28 16:29:44 -0700229 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 consts = []
231
Kevin Cheng989cb052021-04-28 16:29:44 -0700232 assert len(shape_list) == len(dtype_list)
233
Jeremy Johnson1271c442023-09-05 11:39:26 +0100234 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700235 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100236 if not self.args.lazy_data_gen:
237 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700238 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700239
240 return consts
241
242 def makeShape(self, rank):
243 if self.targetted_shape:
244 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800245 return np.int32(
246 self.rng.integers(
247 low=self.args.tensor_shape_range[0],
248 high=self.args.tensor_shape_range[1],
249 size=rank,
250 )
251 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
253 def setTargetShape(self, shape):
254 self.targetted_shape = shape
255
256 def randInt(self, low=0, high=256):
257 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
258
259 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 low, high = self.getDTypeRange(dtype)
261
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100264 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100265 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100266 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100267 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
268 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 elif dtype == DType.BOOL:
270 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000271 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 # Special size
273 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
275 return np.int32(self.rng.integers(low, high, size=1))[0]
276
277 def shapeStr(self, shape):
278
279 sStr = []
280 # Convert to strings
281 for i in shape:
282 sStr.append(str(i))
283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeStr(self, dtype):
287 if isinstance(dtype, list) or isinstance(dtype, tuple):
288 assert len(dtype) >= 2
289 strs = [self.typeStr(t) for t in dtype]
290 # Limit types to the first 2 as the 3rd is the accumulator
291 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100293 if dtype in gtu.DTYPE_ATTRIBUTES:
294 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700295 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100296 raise Exception(
297 "Unknown dtype, cannot convert to string: {}".format(dtype)
298 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100300 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100301 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100302 if dtype in gtu.DTYPE_ATTRIBUTES:
303 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100305 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
Luke Hutton57287132023-02-06 14:54:18 +0000307 def constrictBatchSize(self, shape):
308 # Limit the batch size unless an explicit target shape set
309 if self.args.max_batch_size and not self.args.target_shapes:
310 shape[0] = min(shape[0], self.args.max_batch_size)
311 return shape
312
James Ward30124a82023-02-02 14:56:33 +0000313 def makeDimension(self):
314 return self.randInt(
315 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
316 )
317
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100318 def tensorComplianceMetaData(
319 self, op, inputType, argsDict, outputTensor, errorName
320 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000321 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
322 UNSUPPORTED_NON_FP32_INPUT_OPS = (
323 Op.MATMUL,
324 Op.CONV2D,
325 Op.FULLY_CONNECTED,
326 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000327 Op.TRANSPOSE_CONV2D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000328 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100329 if (
330 errorName
331 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000332 or (
333 not gtu.dtypeIsSupportedByCompliance(inputType)
334 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
335 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100336 ):
337 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100338 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100339
Jeremy Johnson1271c442023-09-05 11:39:26 +0100340 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100341 compliance_tens = {
342 "mode": None,
343 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
344 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
345 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100346 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
347 mode = gtu.ComplianceMode.DOT_PRODUCT
348 compliance_tens["dot_product_info"] = {
349 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100350 "ks": int(argsDict["ksb"])
351 if "ksb" in argsDict
352 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100353 }
354 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
355 mode = gtu.ComplianceMode.FP_SPECIAL
356 elif "compliance" in op and "ulp" in op["compliance"]:
357 mode = gtu.ComplianceMode.ULP
358 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
359 elif op["op"] == Op.REDUCE_PRODUCT:
360 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000361 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000362 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000363 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000364 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
365 compliance_tens["abs_error_info"] = {
366 "lower_bound": op["compliance"]["abs_error_lower_bound"]
367 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100368 else:
369 mode = gtu.ComplianceMode.EXACT
370 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
371
372 return compliance_tens
373
374 # Build Op functions
375 # Create the output tensor (calling OutputShaper as needed)
376 # Do final tweaks to attributes (if necessary for errorIf)
377 # Add Op into graph
378 # Return resulting tensor information or BuildInfo
379
380 class BuildInfo:
381 """Enhanced build information containing result tensor and associated compliance dict."""
382
383 def __init__(self, resultTensor, complianceDict):
384 self.resultTensor = resultTensor
385 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700386
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000387 def build_unary(
388 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
389 ):
390 assert len(inputs) == 1
391 a = inputs[0]
392 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100393
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000394 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100395
396 # Ensure new output type has correct qinfo
397 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000398 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000399 qinfo = [
400 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000401 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000402 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100403
404 # Invalidate Input/Output list for error if checks.
405 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000406 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100407 pCount, cCount = op["operands"]
408 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000409 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
410 self, error_name, input_list, output_list
411 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100412
Les Bell729b0352021-11-24 10:28:21 +0000413 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100414 self.ser,
415 validator_fcns,
416 error_name,
417 op=op,
418 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000419 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000420 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000421 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100422 input_list=input_list,
423 output_list=output_list,
424 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000425 ):
426 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100427
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000428 attr = None
429 if op["op"] == Op.NEGATE:
430 attr = ts.TosaSerializerAttribute()
431 attr.NegateAttribute(qinfo[0], qinfo[1])
432
433 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000434
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000435 compliance = self.tensorComplianceMetaData(
436 op, a.dtype, args_dict, result_tensor, error_name
437 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000438 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700439
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000440 def build_binary_broadcast(
441 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
442 ):
443 assert len(inputs) == 2
444 a, b = inputs
445 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000446 self.ser, self.rng, a, b, error_name
447 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100448
449 # Invalidate Input/Output list for error if checks.
450 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000451 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100452 pCount, cCount = op["operands"]
453 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000454 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
455 self, error_name, input_list, output_list
456 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100457
Les Bell729b0352021-11-24 10:28:21 +0000458 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100459 self.ser,
460 validator_fcns,
461 error_name,
462 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000463 input1=a,
464 input2=b,
465 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000466 output_dtype=result_tensor.dtype,
467 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100468 input_list=input_list,
469 output_list=output_list,
470 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000471 ):
472 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100473
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000475
Jeremy Johnson9a758382023-11-07 16:27:35 +0000476 compliance = self.tensorComplianceMetaData(
477 op, a.dtype, args_dict, result_tensor, error_name
478 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000479
480 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700481
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100482 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700483 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000484 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700485 return result_tens
486
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000487 def build_arithmetic_right_shift(
488 self, op, a, b, round, validator_fcns=None, error_name=None
489 ):
490 result_tens = OutputShaper.binaryBroadcastOp(
491 self.ser, self.rng, a, b, error_name
492 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100493
494 # Invalidate Input/Output list for error if checks.
495 input_list = [a.name, b.name]
496 output_list = [result_tens.name]
497 pCount, cCount = op["operands"]
498 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000499 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
500 self, error_name, input_list, output_list
501 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100502
Les Bell729b0352021-11-24 10:28:21 +0000503 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100504 self.ser,
505 validator_fcns,
506 error_name,
507 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000508 input1=a,
509 input2=b,
510 input_dtype=a.dtype,
511 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000512 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100513 input_list=input_list,
514 output_list=output_list,
515 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000516 ):
517 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800518
519 attr = ts.TosaSerializerAttribute()
520 attr.ArithmeticRightShiftAttribute(round)
521
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000522 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800523 return result_tens
524
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100525 def build_mul(
526 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
527 ):
528 assert len(inputs) == 2
529 a, b = inputs
530 shift = args_dict["shift"]
531
532 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000533 self.ser, self.rng, a, b, error_name
534 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100536 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100537 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100538 result_tensor.setDtype(DType.INT32)
539
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100540 if error_name == ErrorIf.WrongOutputType:
541 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
542 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100543 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100544
545 # Invalidate Input/Output list for error if checks.
546 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100547 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100548 pCount, cCount = op["operands"]
549 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000550 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
551 self, error_name, input_list, output_list
552 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100553
Les Bell729b0352021-11-24 10:28:21 +0000554 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100555 self.ser,
556 validator_fcns,
557 error_name,
558 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000559 input1=a,
560 input2=b,
561 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100562 output_dtype=result_tensor.dtype,
563 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100564 input_list=input_list,
565 output_list=output_list,
566 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000567 ):
568 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
Kevin Chengaee1fac2020-11-11 13:54:06 -0800570 attr = ts.TosaSerializerAttribute()
571 attr.MulAttribute(shift)
572
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000573 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100574
575 compliance = self.tensorComplianceMetaData(
576 op, a.dtype, args_dict, result_tensor, error_name
577 )
578
579 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700580
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
582 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700583
Kevin Chengfe392ce2021-10-18 21:51:55 +0000584 attr = ts.TosaSerializerAttribute()
585 attr.TableAttribute(table)
586
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100587 # Invalidate Input/Output list for error if checks.
588 input_list = [a.name]
589 output_list = [result_tens.name]
590 pCount, cCount = op["operands"]
591 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000592 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
593 self, error_name, input_list, output_list
594 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100595
Les Bell729b0352021-11-24 10:28:21 +0000596 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100597 self.ser,
598 validator_fcns,
599 error_name,
600 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000601 input_shape=a.shape,
602 input_dtype=a.dtype,
603 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000604 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100605 input_list=input_list,
606 output_list=output_list,
607 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000608 ):
609 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100610
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000611 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700612
613 return result_tens
614
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test.

    ``inputs`` must contain exactly three tensors: the condition tensor
    followed by the two value tensors. ``qinfo`` is accepted for interface
    uniformity but is not used.

    Returns ``None`` when the ERROR_IF validators reject the generated test,
    otherwise a ``TosaTestGen.BuildInfo`` holding the result tensor and the
    compliance metadata (computed from the first value tensor's dtype).
    """
    assert len(inputs) == 3
    cond_tensor, input_a, input_b = inputs

    out_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond_tensor, input_a, input_b, error_name
    )

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [cond_tensor.name, input_a.name, input_b.name],
        [out_tensor.name],
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond_tensor,
        input2=input_a,
        input3=input_b,
        input_shape=input_a.shape,
        input_dtype=input_a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=positional_count + const_count,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, input_a.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700662
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test.

    ``inputs`` must contain exactly two tensors. The output tensor is created
    by ``OutputShaper.binaryComparisonOp``. ``qinfo`` is not used.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata
    computed from the first input's dtype.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=positional_count + const_count,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700710
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX operator test along ``args_dict["axis"]``.

    ``inputs`` must contain exactly one tensor. The axis is serialized via an
    AxisAttribute. ``qinfo`` is not used.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    validation_args = {
        "op": op,
        "axis": axis,
        "input_shape": src.shape,
        "input_dtype": src.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": positional_count + const_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700754
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    ``inputs`` must contain exactly one tensor. ``args_dict`` supplies
    ``stride``, ``pad`` and ``kernel``, and optionally ``acc_type`` (absent for
    max pooling, in which case ``DType.UNKNOWN`` is serialized). ``qinfo`` is
    the pair of zero points written into the PoolAttribute; ``None`` is
    treated as ``[0, 0]``.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Wrong-input-type tests with a non-8-bit input get zero points derived
    # from the actual input/output dtypes so qinfo matches the new output type.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=positional_count + const_count,
    )
    if not checks_pass:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100828
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and the summed
            ``op["operands"]`` is passed to the ERROR_IF checks.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (4 values)
            and ``dilation``; also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the ConvAttribute (presumably
            input and weight zero points - confirm against the serializer).
            ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700910
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and the summed
            ``op["operands"]`` is passed to the ERROR_IF checks.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (6 values)
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the ConvAttribute (presumably
            input and weight zero points - confirm against the serializer).
            ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise the
        result tensor itself (no BuildInfo / compliance metadata is produced
        for this op here).
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
987
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and the summed
            ``op["operands"]`` is passed to the ERROR_IF checks.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` (4 values)
            and ``out_shape``; also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the TransposeConvAttribute
            (presumably input and weight zero points - confirm against the
            serializer). ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001062
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and the summed
            ``op["operands"]`` is passed to the ERROR_IF checks.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Must provide ``acc_type``, ``stride``, ``pad`` and
            ``dilation``; also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the ConvAttribute (presumably
            input and weight zero points - confirm against the serializer).
            ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001143
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and
            ``op["operands"]`` gives the operand counts for ERROR_IF checks.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Must provide ``acc_type``; also forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the FullyConnectedAttribute.
            ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001199
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: Operator definition; ``op["op"]`` is serialized and
            ``op["operands"]`` gives the operand counts for ERROR_IF checks.
        inputs: The two input tensors ``[a, b]``.
        args_dict: Must provide ``acc_type``; also forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or ``None`` for a valid test.
        qinfo: Pair of zero points written into the MatMulAttribute.
            ``None`` is treated as ``[0, 0]``.

    Returns:
        ``None`` if the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # rather than failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001249
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator test along ``args_dict["axis"]``.

    ``inputs`` must contain exactly one tensor. For a valid (non-ERROR_IF)
    REDUCE_PRODUCT test, the size of the reduced axis is stored in
    ``args_dict["n"]`` (mutating the caller's dict) before the compliance
    metadata is computed. ``qinfo`` is not used.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=positional_count + const_count,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001298
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator test with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` test the pair is redrawn until the values differ,
    then assigned the other way round so that max < min.

    Floating-point inputs (BF16/FP16/FP32) put the bounds in the float slots
    of the ClampAttribute (FP16 bounds converted to float32 first); all other
    dtypes use the integer slots. ``qinfo`` is not used.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def draw_pair():
        return [self.getRandNumberDType(src.dtype), self.getRandNumberDType(src.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    positional_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=positional_count + const_count,
    )
    if not checks_pass:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001364
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator on ``a`` with a random FP32 alpha attribute.

    Legacy-style builder: ``validator_fcns`` is accepted but unused (no
    ERROR_IF validation runs), and the bare output tensor is returned
    rather than a ``TosaTestGen.BuildInfo``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    # Keep alpha drawn after the output tensor is created: both steps may
    # consume generator random state, so their order is significant.
    alpha = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1373
# Needs an additional type/input; only ``a`` is currently passed to the operator.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PRELU operator whose only wired-up input is ``a``.

    ``validator_fcns`` is accepted but unused, and the bare output tensor is
    returned.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1380
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that has no attribute.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` a two-element count pair (presumably
            placeholder and const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001421
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator over all tensors in ``inputs``.

    CONCAT_SHAPE always uses axis 0 and is serialized without an attribute.
    CONCAT takes its axis from ``args_dict["axis"]`` and serializes it as an
    AxisAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    is_shape_concat = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_concat else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in inputs], [out_tensor.name]
    )

    first = inputs[0]
    checks = dict(
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    if op["op"] != Op.CONCAT:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    else:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001480
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator from a data tensor and a padding tensor.

    ``args_dict`` must provide ``"pad"``, ``"pad_const_int"`` and
    ``"pad_const_fp"``. The padding values themselves reach the operator
    through ``inputs[1]``; the PadAttribute is written with an empty padding
    list and only carries the pad constants.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    data, pad_tensor = inputs
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, data, padding, error_name)

    # Empty padding in the attribute so the padding input tensor is what
    # gets used. Created before validation, as it writes via the builder.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, pad_tensor.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=data,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001538
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying ``args_dict["axis"]`` of the single input.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001584
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator from a data tensor and a shape tensor.

    The output shape is computed from ``args_dict["new_shape"]``. The
    operator is serialized without an attribute, so the target shape reaches
    it through the second input tensor.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    data, shape_tensor = inputs
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, data, new_shape, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, shape_tensor.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001628
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along ``args_dict["axis"]`` of the single input.

    NOTE(review): unlike e.g. build_transpose, no compliance metadata is
    produced here (BuildInfo gets None); confirm this is intended.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001668
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    perms = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, src, perms, error_name)

    # The attribute is created before validation, as in the other
    # attribute-before-validation builders.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001717
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator from data, start and size input tensors.

    The output shape and ERROR_IF checks use the constant ``args_dict["start"]``
    and ``args_dict["size"]`` values. Those are also still serialized as a
    SliceAttribute alongside the start/size tensors.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    data, start_tensor, size_tensor = inputs
    start = args_dict["start"]
    size = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, data, start, size, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [data.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=data,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # TODO remove the slice attribute once shape dynamism support is mature.
    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001769
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator from a data tensor and a multiples tensor.

    The output shape is computed from ``args_dict["multiples"]``. No
    attribute is serialized, so the multiples reach the operator through the
    second input tensor.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    data, multiples_tensor = inputs
    multiples = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, data, multiples, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, multiples_tensor.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=data,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001814
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from ``(values, indices)`` input tensors.

    Shape/dtype ERROR_IF checks and the compliance metadata are keyed on the
    ``values`` tensor.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001857
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from ``(values_in, indices, input)`` tensors.

    Shape/dtype ERROR_IF checks and the compliance metadata are keyed on the
    ``values_in`` tensor.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor + compliance metadata), or
        None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # ``input_tensor`` avoids shadowing the ``input`` builtin.
    values_in, indices, input_tensor = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001899
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are used both to compute the
    output tensor and to populate the ResizeAttribute. Legacy-style builder:
    returns the bare result tensor, or None if ERROR_IF validation rejects
    the test. (The parameter name ``input`` is kept for caller compatibility.)
    """
    src = input
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [resized.name]
    )

    checks = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1961
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator that forwards two tensors unchanged.

    Args:
        op: operator description. When it is a dict (as for every other
            builder) the serializer op code is taken from ``op["op"]``;
            otherwise ``op`` is used directly as the op code.
        val: first input tensor.
        val2: second input tensor.
        validator_fcns: unused - no ERROR_IF validation is run here.
        error_name: forwarded to OutputShaper.unaryOp.

    Returns:
        The first output tensor. The second output is only registered with
        the serializer.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The other builders hand the serializer op["op"], not the whole
    # description dict; still accept a bare op code for backward compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1969
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Expose a single constant tensor as a graph output.

    Args:
        op: operator description, passed on to the compliance helper.
        inputs: list containing exactly one tensor.
        args_dict: test arguments, passed on to the compliance helper.
        validator_fcns: unused.
        error_name: passed on to the compliance helper.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo pairing the tensor with its compliance metadata.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]

    self.ser.addOutputTensor(const_tensor)

    meta = self.tensorComplianceMetaData(
        op, const_tensor.dtype, args_dict, const_tensor, error_name
    )
    return TosaTestGen.BuildInfo(const_tensor, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001982
1983 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator converting inputs[0] to args_dict["out_type"].

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list containing exactly one input tensor.
        args_dict: test arguments; "out_type" selects the target dtype.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]

    output = OutputShaper.typeConversionOp(
        self.ser, self.rng, source, args_dict["out_type"], error_name
    )

    placeholder_count, const_count = op["operands"]
    # Let the ERROR_IF helper corrupt the operand name lists when required.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [output.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        output_shape=output.shape,
        input_dtype=source.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002027
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator with random zero points and scales.

    Zero points are drawn according to the input/output dtypes. For the
    zero-point ERROR_IF tests they are forced non-zero instead. A random
    scale, one per channel of ``val``'s last dimension or a single one, is
    converted into multiplier/shift pairs. For valid scale32 tests, the
    stored input values are clamped to satisfy the apply_scale_32 shift
    requirement. The placeholder file is rewritten if any value changed.

    Args:
        op: operator description dict (uses "op" and "operands").
        val: input placeholder tensor.
        out_dtype: output DType.
        scale32: use 32-bit (True) or 16-bit (False) multipliers.
        double_round: value for the double_round attribute.
        per_channel: one scale per channel of the last dimension when True.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bit arithmetic on a Python int: shifting by an np.int32 keeps
        # the result int32 under NumPy 2 promotion rules and overflows for
        # large shifts.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype. Compare
        # element-wise: ndarray.all() would reduce to a single bool first.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2200
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor consumed by COND_IF.

    Normally this is a rank-0 BOOL constant holding ``cond``. The dtype is
    replaced by a wrong type for ErrorIf.CondIfCondNotMatchingBool. The shape
    is made non-size-one for ErrorIf.CondIfCondShapeNotSizeOne.
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    # Size 1 (rank 0) unless the shape itself is under test.
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2217
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each return an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Both
    blocks emit a random INT32 constant of that shape. For the
    CondIfOutputListThenGraphMismatch / CondIfOutputListElseGraphMismatch
    ERROR_IF tests, the matching block instead emits a constant with a
    deliberately different shape.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Perturb every dimension: dims above 3 may shrink or grow, smaller
        # dims only grow.
        bad_shape = deepcopy(out_shape)
        for dim_idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[dim_idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[dim_idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # COND_IF output tensor
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block: (name, error that corrupts it, normal constant data).
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, mismatch_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_const = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_const = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_const)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2286
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to a and b.

    For FP32/BF16/FP16/INT32 the then block uses ADD and the else block SUB.
    For INT8/INT16 they use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. For
    the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests, the
    selected block gets an input or output whose shape differs from ``a``.

    Args:
        op: operator description dict (uses "op"); also passed to the
            ERROR_IF validators.
        a: first operand tensor (also defines the result shape/dtype).
        b: second operand tensor.
        cond: condition value for the constant condition tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be named ``op``: that would overwrite the
    # operator description handed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2369
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` while a counter is positive.

    The loop carries three values: an INT32 counter initialised to
    ``iter_val``, ``a``, and an accumulator placeholder initialised to zeros.
    The COND block computes ``counter > 0``. The BODY block adds ``a`` to the
    accumulator and subtracts one from the counter. For the while-loop
    ERROR_IF tests, selected block inputs/outputs get deliberately wrong
    shapes or a wrong cond dtype/shape.

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation
        rejects the test.
    """
    jitter_choices = [-3, -2, 2, 3]

    def _jittered_copy(tensor):
        # Deep copy so the source tensor's shape list is left untouched.
        clone = deepcopy(tensor)
        for dim_idx in range(len(clone.shape)):
            clone.shape[dim_idx] += self.rng.choice(jitter_choices)
        return clone

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator starts at zero.
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Outputs of the WHILE_LOOP operator itself.
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = _jittered_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        incorrect_iter = _jittered_copy(iter_tens)
        # The counter is rank 0, so give it a bogus dimension instead.
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice(jitter_choices))
        incorrect_acc = _jittered_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc).
    # Local intermediate tensors are declared here for the outputs.
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_template, acc_template = incorrect_iter, incorrect_acc
    else:
        iter_template, acc_template = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(
        iter_template.shape, iter_template.dtype
    )
    acc_body_out = self.ser.addIntermediate(acc_template.shape, acc_template.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tensor in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if is_valid else None
2490
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator over the two input tensors ``val1``/``val2``.

    Args:
        op: operator description dict (uses "op" and "operands").
        val1: first input tensor; its shape/dtype are what is validated.
        val2: second input tensor.
        inverse: value for the FFT attribute's inverse flag.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.

    Returns:
        The list of output tensors from OutputShaper.fft2dOp, or None when
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tensor.name for tensor in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2541
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator over the single input tensor ``val``.

    Args:
        op: operator description dict (uses "op" and "operands").
        val: input tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.

    Returns:
        The list of output tensors from OutputShaper.rfft2dOp, or None when
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tensor.name for tensor in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2587
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary shape operator on the two tensors in ``inputs``.

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list containing exactly two input tensors.
        args_dict: test arguments, passed on to the compliance helper.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or None.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    output = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]
    # Let the ERROR_IF helper corrupt the operand name lists when required.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [output.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=output.shape,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
2633
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator supports.

    Args:
        op: Operator entry; reads op["rank"] (inclusive (min, max) pair) and
            op["types"].
        shapeFilter: Requested shapes; a falsy value means "no shape filter".
        rankFilter: Requested ranks, or None for the default rank selection.
        dtypeFilter: Requested dtypes, or None to keep every operator dtype.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests; it is called with
            check=False to obtain its "param_reqs" overrides.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or an unrecognised testType.
    """
    shapes = shapeFilter if shapeFilter else [None]

    low, high = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator allows
        ranks = [rank for rank in rankFilter if low <= rank <= high]
    elif shapes[0] is not None:
        # Explicit shapes decide the rank, so allow every rank the op supports
        ranks = range(low, high + 1)
    else:
        # Default to ranks 1-4 inclusive to keep test sizes reasonably small,
        # unless the operator's own rank range is no wider than that.
        op_ranks = range(low, high + 1)
        default_ranks = range(1, 5)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # An entry may itself be a list of dtypes (e.g. the TYPE_CONV entries);
        # such entries are selected by their first element.
        dtypes = [
            entry
            for entry in op_dtypes
            if entry in dtypeFilter
            or (isinstance(entry, list) and entry[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    # The validator may force particular ranks / dtypes / shapes
    rank_req = reqs["rank"]
    dtype_req = reqs["dtype"]
    shape_req = reqs["shape"]
    return {
        # Without a forced shape, keep just two shapes to limit test numbers
        "shapeFilter": shapes[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2714
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the test cases to generate for one operator.

    Args:
        opName: Key into TOSA_OP_LIST.
        shapeFilter: Optional list of shapes to restrict generation to. None
            (the default) or an empty list means "no shape filter";
            create_filter_lists turns a falsy value into [None].
        rankFilter: Optional list of ranks to restrict generation to.
        dtypeFilter: Optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        List of (opName, testStr, dtype, error_name, shapeList, args) tuples.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator (e.g. a negative test type
            # without validators); skip it without discarding tests already
            # collected for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    # testType can only be "positive" or "negative" here:
                    # create_filter_lists returns None for anything else.
                    if testType == "positive":
                        baseStr = f"{opName}_{shapeStr}_{typeStr}"
                    else:
                        baseStr = f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"

                    for argStr, args in argList:
                        testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. Every validator is still evaluated per test.
        invalid_test_validators = op["invalid_test_validators"]
        clean_testList = []
        for test in testList:
            verdicts = [
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            ]
            if not any(verdicts):
                clean_testList.append(test)
        testList = clean_testList

    return testList
2821
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test case for opName and serialize it as "test".

    The op entry's build_fcn tuple supplies the operator builder and the
    tensor-values generator. testArgs is either a dict (new interface; must
    contain "dg_type") or a sequence of positional build arguments (old
    interface). If the builder returns a falsy result, a message is printed
    and nothing is serialized.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # CONCAT-style ops take a variable number of inputs (one per shape)
    variable_inputs = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if variable_inputs else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not variable_inputs:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info, when the op defines a generator for it
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: build arguments arrive as a dictionary
        assert "dg_type" in testArgs
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, testArgs, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        result = build_fcn(
            self,
            op,
            tvgInfo.tensorList,
            testArgs,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: build arguments arrive as a positional list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        try:
            call_args = [self, op, *tens, *testArgs]
            call_kwargs = {}
            if error_if_validators is None:
                # Without validators qinfo is passed positionally
                if qinfo is not None:
                    call_args.append(qinfo)
            else:
                call_kwargs["validator_fcns"] = error_if_validators
                call_kwargs["error_name"] = error_name
                if qinfo is not None:
                    call_kwargs["qinfo"] = qinfo
            result = build_fcn(*call_args, **call_kwargs)
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002947
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    Each 2D template (conv2d, depthwise_conv2d, transpose_conv2d) is cloned
    once per 2D kernel as "<prefix>_<k0>x<k1>", and conv3d once per 3D kernel
    as "conv3d_<k0>x<k1>x<k2>". Each clone gets "filter" set to its kernel and
    "template" set to False. Afterwards every entry whose "template" value is
    truthy is removed. If "conv2d_TEMPLATE" is already gone, nothing is done.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def clone(template_name, new_name, kernel):
        # Shallow copy: only the top-level "filter"/"template" keys differ
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[new_name] = entry

    for kernel in kernels_2d:
        suffix = f"{kernel[0]}x{kernel[1]}"
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            clone(f"{prefix}_TEMPLATE", f"{prefix}_{suffix}", kernel)

    for kernel in kernels_3d:
        clone(
            "conv3d_TEMPLATE", f"conv3d_{kernel[0]}x{kernel[1]}x{kernel[2]}", kernel
        )

    # Collect first, then delete: removing keys while iterating a dict is an error
    stale = [name for name, entry in op_list.items() if entry.get("template", False)]
    for name in stale:
        del op_list[name]
3003
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide a 2-element "operands" value, a 4-element
    "build_fcn" value, a "types" field and an "op" field. An entry without
    "rank" gets DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op found with a missing or malformed
            required field.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        # Required: (build, tensor gen, tensor values gen, arg gen) functions
        try:
            _build, _tgen, _tvgen, _argGen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        # Required fields that only need to be present
        for field, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            try:
                _ = entry[field]
            except KeyError:
                raise Exception(message.format(name))

        # Optional: rank range falls back to the class default
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003045
# TOSA_OP_LIST, defined after these type lists, describes each operator test:
#   'op': the tosa Op value of the operator
#   'operands': tuple of (placeholder, const) operand counts
#   'rank': optional inclusive (min, max) rank range; when missing,
#       initOpListDefaults sets it to DEFAULT_RANK_RANGE
#   'build_fcn': tuple of (build function, tensor shape generator,
#       tensor values generator, argument generator)
#   'types': list of datatypes to be tested; an entry may itself be a list
#       of per-operand datatypes (see TYPE_CONV)
# Optional fields include 'qgen', 'error_if_validators',
# 'invalid_test_validators', 'data_gen', and 'template'/'filter' for the
# entries expanded by createDynamicOpLists.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
# Integer types followed by floating-point types (still excludes INT4)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# Each entry is [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003097
3098 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003099 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003100 "argmax": {
3101 "op": Op.ARGMAX,
3102 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003103 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 "build_fcn": (
3105 build_argmax,
3106 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003107 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003108 TosaArgGen.agAxis,
3109 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003110 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003111 "error_if_validators": (
3112 TosaErrorValidator.evAxisSmallerZero,
3113 TosaErrorValidator.evAxisLargerRank,
3114 TosaErrorValidator.evArgmaxOutputRankMismatch,
3115 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3116 TosaErrorValidator.evWrongRank,
3117 TosaErrorValidator.evWrongInputType,
3118 TosaErrorValidator.evWrongOutputType,
3119 TosaErrorValidator.evWrongInputList,
3120 TosaErrorValidator.evWrongOutputList,
3121 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003122 "data_gen": {
3123 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3124 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "avg_pool2d": {
3127 "op": Op.AVG_POOL2D,
3128 "operands": (1, 0),
3129 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 "build_fcn": (
3131 build_pool2d,
3132 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003133 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003134 TosaArgGen.agPooling,
3135 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003136 "qgen": TosaQuantGen.qgUnary,
3137 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003138 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003139 "error_if_validators": (
3140 TosaErrorValidator.evKernelSmallerOne,
3141 TosaErrorValidator.evStrideSmallerOne,
3142 TosaErrorValidator.evPadSmallerZero,
3143 TosaErrorValidator.evWrongRank,
3144 TosaErrorValidator.evWrongInputType,
3145 TosaErrorValidator.evWrongOutputType,
3146 TosaErrorValidator.evWrongInputList,
3147 TosaErrorValidator.evWrongOutputList,
3148 TosaErrorValidator.evInputZeroPointNotZero,
3149 TosaErrorValidator.evOutputZeroPointNotZero,
3150 TosaErrorValidator.evPadLargerEqualKernel,
3151 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003152 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003154 "data_gen": {
3155 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003158 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003159 "conv2d_TEMPLATE": {
3160 "op": Op.CONV2D,
3161 "operands": (1, 2),
3162 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163 "build_fcn": (
3164 build_conv2d,
3165 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003166 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003167 TosaArgGen.agConv,
3168 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003169 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003170 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003171 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3172 "error_if_validators": (
3173 TosaErrorValidator.evWrongInputType,
3174 TosaErrorValidator.evWrongOutputType,
3175 TosaErrorValidator.evWrongInputList,
3176 TosaErrorValidator.evWrongOutputList,
3177 TosaErrorValidator.evInputZeroPointNotZero,
3178 TosaErrorValidator.evWeightZeroPointNotZero,
3179 TosaErrorValidator.evPadSmallerZero,
3180 TosaErrorValidator.evStrideSmallerOne,
3181 TosaErrorValidator.evDilationSmallerOne,
3182 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003183 TosaErrorValidator.evConvOutputShapeMismatch,
3184 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003185 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003186 "data_gen": {
3187 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3188 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003189 "template": True,
3190 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003191 # Templated operator. Filled in by createDynamicOpLists
3192 "conv3d_TEMPLATE": {
3193 "op": Op.CONV3D,
3194 "operands": (1, 2),
3195 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003196 "build_fcn": (
3197 build_conv3d,
3198 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003199 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003200 TosaArgGen.agConv,
3201 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003202 "qgen": TosaQuantGen.qgConv,
3203 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003204 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3205 "error_if_validators": (
3206 TosaErrorValidator.evWrongInputType,
3207 TosaErrorValidator.evWrongOutputType,
3208 TosaErrorValidator.evWrongInputList,
3209 TosaErrorValidator.evWrongOutputList,
3210 TosaErrorValidator.evInputZeroPointNotZero,
3211 TosaErrorValidator.evWeightZeroPointNotZero,
3212 TosaErrorValidator.evPadSmallerZero,
3213 TosaErrorValidator.evStrideSmallerOne,
3214 TosaErrorValidator.evDilationSmallerOne,
3215 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003216 TosaErrorValidator.evConvOutputShapeMismatch,
3217 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003218 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003219 "template": True,
3220 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003221 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003222 "depthwise_conv2d_TEMPLATE": {
3223 "op": Op.DEPTHWISE_CONV2D,
3224 "operands": (1, 2),
3225 "filter": [1, 1],
3226 "rank": (4, 4),
3227 "build_fcn": (
3228 build_depthwise_conv2d,
3229 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003230 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003231 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003232 ),
3233 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003234 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003235 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3236 "error_if_validators": (
3237 TosaErrorValidator.evWrongInputType,
3238 TosaErrorValidator.evWrongOutputType,
3239 TosaErrorValidator.evWrongInputList,
3240 TosaErrorValidator.evWrongOutputList,
3241 TosaErrorValidator.evInputZeroPointNotZero,
3242 TosaErrorValidator.evWeightZeroPointNotZero,
3243 TosaErrorValidator.evPadSmallerZero,
3244 TosaErrorValidator.evStrideSmallerOne,
3245 TosaErrorValidator.evDilationSmallerOne,
3246 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003247 TosaErrorValidator.evConvOutputShapeMismatch,
3248 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003249 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003250 "data_gen": {
3251 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3252 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003253 "template": True,
3254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003255 "fully_connected": {
3256 "op": Op.FULLY_CONNECTED,
3257 "operands": (1, 2),
3258 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003259 "build_fcn": (
3260 build_fully_connected,
3261 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003262 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003263 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003266 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003267 "error_if_validators": (
3268 TosaErrorValidator.evInputZeroPointNotZero,
3269 TosaErrorValidator.evWeightZeroPointNotZero,
3270 TosaErrorValidator.evWrongRank,
3271 TosaErrorValidator.evWrongInputType,
3272 TosaErrorValidator.evWrongOutputType,
3273 TosaErrorValidator.evWrongInputList,
3274 TosaErrorValidator.evWrongOutputList,
3275 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003276 "data_gen": {
3277 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 "matmul": {
3281 "op": Op.MATMUL,
3282 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003283 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 "build_fcn": (
3285 build_matmul,
3286 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003287 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003288 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "qgen": TosaQuantGen.qgMatmul,
3291 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003292 "error_if_validators": (
3293 TosaErrorValidator.evInputZeroPointNotZero,
3294 TosaErrorValidator.evWrongRank,
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003300 "data_gen": {
3301 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "max_pool2d": {
3305 "op": Op.MAX_POOL2D,
3306 "operands": (1, 0),
3307 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003309 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003311 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 TosaArgGen.agPooling,
3313 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003314 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003315 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003316 "error_if_validators": (
3317 TosaErrorValidator.evKernelSmallerOne,
3318 TosaErrorValidator.evStrideSmallerOne,
3319 TosaErrorValidator.evPadSmallerZero,
3320 TosaErrorValidator.evWrongRank,
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 TosaErrorValidator.evPadLargerEqualKernel,
3326 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003327 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003328 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003329 "data_gen": {
3330 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3331 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003333 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003334 "transpose_conv2d_TEMPLATE": {
3335 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003336 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003337 "rank": (4, 4),
3338 "build_fcn": (
3339 build_transpose_conv2d,
3340 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003341 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003342 TosaArgGen.agTransposeConv2D,
3343 ),
3344 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003345 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003346 "invalid_test_validators": (
3347 TosaInvalidValidator.ivHeightWidthInvalid,
3348 TosaInvalidValidator.ivNonPositiveOutputShape,
3349 ),
3350 "error_if_validators": (
3351 TosaErrorValidator.evWrongInputType,
3352 TosaErrorValidator.evWrongOutputType,
3353 TosaErrorValidator.evWrongInputList,
3354 TosaErrorValidator.evWrongOutputList,
3355 TosaErrorValidator.evInputZeroPointNotZero,
3356 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003357 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003358 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003359 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003360 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003361 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003362 "data_gen": {
3363 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3364 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003365 "template": True,
3366 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003367 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003368 "clamp": {
3369 "op": Op.CLAMP,
3370 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003371 "build_fcn": (
3372 build_clamp,
3373 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003374 TosaTensorValuesGen.tvgLazyGenDefault,
3375 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003376 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003377 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 "error_if_validators": (
3379 TosaErrorValidator.evMaxSmallerMin,
3380 TosaErrorValidator.evWrongInputType,
3381 TosaErrorValidator.evWrongOutputType,
3382 TosaErrorValidator.evWrongInputList,
3383 TosaErrorValidator.evWrongOutputList,
3384 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003385 "data_gen": {
3386 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3387 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003388 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003389 "sigmoid": {
3390 "op": Op.SIGMOID,
3391 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003393 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003394 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003395 TosaTensorValuesGen.tvgLazyGenDefault,
3396 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003398 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 "error_if_validators": (
3400 TosaErrorValidator.evWrongInputType,
3401 TosaErrorValidator.evWrongOutputType,
3402 TosaErrorValidator.evWrongInputList,
3403 TosaErrorValidator.evWrongOutputList,
3404 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003405 "data_gen": {
3406 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3407 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003408 },
3409 "tanh": {
3410 "op": Op.TANH,
3411 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003413 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003415 TosaTensorValuesGen.tvgLazyGenDefault,
3416 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003417 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003418 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003419 "error_if_validators": (
3420 TosaErrorValidator.evWrongInputType,
3421 TosaErrorValidator.evWrongOutputType,
3422 TosaErrorValidator.evWrongInputList,
3423 TosaErrorValidator.evWrongOutputList,
3424 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003425 "data_gen": {
3426 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3427 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003428 "compliance": {
3429 "abs_error_lower_bound": 0.5,
3430 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003431 },
Won Jeon78155c62023-06-10 00:20:04 +00003432 "erf": {
3433 "op": Op.ERF,
3434 "operands": (1, 0),
3435 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003436 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003437 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003438 TosaTensorValuesGen.tvgLazyGenDefault,
3439 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003440 ),
3441 "types": TYPE_FP,
3442 "error_if_validators": (
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003448 "data_gen": {
3449 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3450 },
3451 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003452 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 # Elementwise Binary Operators
3454 "add": {
3455 "op": Op.ADD,
3456 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 "build_fcn": (
3458 build_binary_broadcast,
3459 TosaTensorGen.tgBroadcastFuzz,
3460 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003461 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 "error_if_validators": (
3465 TosaErrorValidator.evRankMismatch,
3466 TosaErrorValidator.evWrongInputType,
3467 TosaErrorValidator.evWrongOutputType,
3468 TosaErrorValidator.evWrongInputList,
3469 TosaErrorValidator.evWrongOutputList,
3470 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003471 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003472 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003473 "data_gen": {
3474 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3475 },
3476 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 "arithmetic_right_shift": {
3479 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3480 "operands": (2, 0),
3481 "build_fcn": (
3482 build_arithmetic_right_shift,
3483 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003484 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 TosaArgGen.agArithmeticRightShift,
3486 ),
3487 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003488 "error_if_validators": (
3489 TosaErrorValidator.evRankMismatch,
3490 TosaErrorValidator.evWrongInputType,
3491 TosaErrorValidator.evWrongOutputType,
3492 TosaErrorValidator.evWrongInputList,
3493 TosaErrorValidator.evWrongOutputList,
3494 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003495 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003496 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003497 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 "bitwise_and": {
3499 "op": Op.BITWISE_AND,
3500 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003501 "build_fcn": (
3502 build_binary_broadcast,
3503 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003504 TosaTensorValuesGen.tvgLazyGenDefault,
3505 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003506 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003507 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003508 "error_if_validators": (
3509 TosaErrorValidator.evRankMismatch,
3510 TosaErrorValidator.evWrongInputType,
3511 TosaErrorValidator.evWrongOutputType,
3512 TosaErrorValidator.evWrongInputList,
3513 TosaErrorValidator.evWrongOutputList,
3514 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003515 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003517 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003518 "bitwise_or": {
3519 "op": Op.BITWISE_OR,
3520 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 "build_fcn": (
3522 build_binary_broadcast,
3523 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003524 TosaTensorValuesGen.tvgLazyGenDefault,
3525 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003528 "error_if_validators": (
3529 TosaErrorValidator.evRankMismatch,
3530 TosaErrorValidator.evWrongInputType,
3531 TosaErrorValidator.evWrongOutputType,
3532 TosaErrorValidator.evWrongInputList,
3533 TosaErrorValidator.evWrongOutputList,
3534 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003535 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 "bitwise_xor": {
3539 "op": Op.BITWISE_XOR,
3540 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 "build_fcn": (
3542 build_binary_broadcast,
3543 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003544 TosaTensorValuesGen.tvgLazyGenDefault,
3545 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003548 "error_if_validators": (
3549 TosaErrorValidator.evRankMismatch,
3550 TosaErrorValidator.evWrongInputType,
3551 TosaErrorValidator.evWrongOutputType,
3552 TosaErrorValidator.evWrongInputList,
3553 TosaErrorValidator.evWrongOutputList,
3554 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003555 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003556 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003557 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003558 "intdiv": {
3559 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003560 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 "build_fcn": (
3562 build_binary_broadcast,
3563 TosaTensorGen.tgBroadcastFuzz,
3564 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003565 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003566 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003567 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003568 "error_if_validators": (
3569 TosaErrorValidator.evRankMismatch,
3570 TosaErrorValidator.evWrongInputType,
3571 TosaErrorValidator.evWrongOutputType,
3572 TosaErrorValidator.evWrongInputList,
3573 TosaErrorValidator.evWrongOutputList,
3574 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003575 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003577 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003578 "logical_and": {
3579 "op": Op.LOGICAL_AND,
3580 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003581 "build_fcn": (
3582 build_binary_broadcast,
3583 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003584 TosaTensorValuesGen.tvgLazyGenDefault,
3585 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003587 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003588 "error_if_validators": (
3589 TosaErrorValidator.evRankMismatch,
3590 TosaErrorValidator.evWrongInputType,
3591 TosaErrorValidator.evWrongOutputType,
3592 TosaErrorValidator.evWrongInputList,
3593 TosaErrorValidator.evWrongOutputList,
3594 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003595 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003596 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 "logical_left_shift": {
3599 "op": Op.LOGICAL_LEFT_SHIFT,
3600 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003601 "build_fcn": (
3602 build_binary_broadcast,
3603 TosaTensorGen.tgBroadcastFuzz,
3604 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003605 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003606 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003608 "error_if_validators": (
3609 TosaErrorValidator.evRankMismatch,
3610 TosaErrorValidator.evWrongInputType,
3611 TosaErrorValidator.evWrongOutputType,
3612 TosaErrorValidator.evWrongInputList,
3613 TosaErrorValidator.evWrongOutputList,
3614 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003615 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 "logical_right_shift": {
3619 "op": Op.LOGICAL_RIGHT_SHIFT,
3620 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621 "build_fcn": (
3622 build_binary_broadcast,
3623 TosaTensorGen.tgBroadcastFuzz,
3624 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003625 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003628 "error_if_validators": (
3629 TosaErrorValidator.evRankMismatch,
3630 TosaErrorValidator.evWrongInputType,
3631 TosaErrorValidator.evWrongOutputType,
3632 TosaErrorValidator.evWrongInputList,
3633 TosaErrorValidator.evWrongOutputList,
3634 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003635 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003636 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003637 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 "logical_or": {
3639 "op": Op.LOGICAL_OR,
3640 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_binary_broadcast,
3643 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003644 TosaTensorValuesGen.tvgLazyGenDefault,
3645 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evRankMismatch,
3650 TosaErrorValidator.evWrongInputType,
3651 TosaErrorValidator.evWrongOutputType,
3652 TosaErrorValidator.evWrongInputList,
3653 TosaErrorValidator.evWrongOutputList,
3654 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003655 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003656 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003657 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "logical_xor": {
3659 "op": Op.LOGICAL_XOR,
3660 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 "build_fcn": (
3662 build_binary_broadcast,
3663 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003664 TosaTensorValuesGen.tvgLazyGenDefault,
3665 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003666 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003667 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003668 "error_if_validators": (
3669 TosaErrorValidator.evRankMismatch,
3670 TosaErrorValidator.evWrongInputType,
3671 TosaErrorValidator.evWrongOutputType,
3672 TosaErrorValidator.evWrongInputList,
3673 TosaErrorValidator.evWrongOutputList,
3674 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003675 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003676 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "maximum": {
3679 "op": Op.MAXIMUM,
3680 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003681 "build_fcn": (
3682 build_binary_broadcast,
3683 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003684 TosaTensorValuesGen.tvgLazyGenDefault,
3685 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003688 "error_if_validators": (
3689 TosaErrorValidator.evRankMismatch,
3690 TosaErrorValidator.evWrongInputType,
3691 TosaErrorValidator.evWrongOutputType,
3692 TosaErrorValidator.evWrongInputList,
3693 TosaErrorValidator.evWrongOutputList,
3694 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003695 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003697 "data_gen": {
3698 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3699 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "minimum": {
3702 "op": Op.MINIMUM,
3703 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_binary_broadcast,
3706 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003707 TosaTensorValuesGen.tvgLazyGenDefault,
3708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evRankMismatch,
3713 TosaErrorValidator.evWrongInputType,
3714 TosaErrorValidator.evWrongOutputType,
3715 TosaErrorValidator.evWrongInputList,
3716 TosaErrorValidator.evWrongOutputList,
3717 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003718 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003720 "data_gen": {
3721 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3722 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "mul": {
3725 "op": Op.MUL,
3726 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 "build_fcn": (
3728 build_mul,
3729 TosaTensorGen.tgBroadcastFuzz,
3730 TosaTensorValuesGen.tvgMul,
3731 TosaArgGen.agMul,
3732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evWrongInputType,
3736 TosaErrorValidator.evWrongOutputType,
3737 TosaErrorValidator.evWrongInputList,
3738 TosaErrorValidator.evWrongOutputList,
3739 TosaErrorValidator.evRankMismatch,
3740 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003741 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003743 "data_gen": {
3744 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3745 },
3746 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "pow": {
3749 "op": Op.POW,
3750 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003751 "build_fcn": (
3752 build_binary_broadcast,
3753 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003754 TosaTensorValuesGen.tvgPow,
3755 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003757 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003758 "error_if_validators": (
3759 TosaErrorValidator.evRankMismatch,
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
3764 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003765 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003766 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003767 "data_gen": {
3768 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3769 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 "sub": {
3772 "op": Op.SUB,
3773 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 "build_fcn": (
3775 build_binary_broadcast,
3776 TosaTensorGen.tgBroadcastFuzz,
3777 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003778 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003781 "error_if_validators": (
3782 TosaErrorValidator.evRankMismatch,
3783 TosaErrorValidator.evWrongInputType,
3784 TosaErrorValidator.evWrongOutputType,
3785 TosaErrorValidator.evWrongInputList,
3786 TosaErrorValidator.evWrongOutputList,
3787 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003788 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003789 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003790 "data_gen": {
3791 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3792 },
3793 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003794 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "table": {
3796 "op": Op.TABLE,
3797 # Use the automatic generation functions to create the input array
3798 # but create the table tensor in the build function, as it may be
3799 # a different type from the input
3800 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 "build_fcn": (
3802 build_table,
3803 TosaTensorGen.tgBasic,
3804 TosaTensorValuesGen.tvgDefault,
3805 TosaArgGen.agTable,
3806 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003807 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 "error_if_validators": (
3809 TosaErrorValidator.evWrongInputType,
3810 TosaErrorValidator.evWrongOutputType,
3811 TosaErrorValidator.evWrongInputList,
3812 TosaErrorValidator.evWrongOutputList,
3813 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003814 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 # Elementwise Unary operators
3816 "abs": {
3817 "op": Op.ABS,
3818 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003819 "build_fcn": (
3820 build_unary,
3821 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003822 TosaTensorValuesGen.tvgLazyGenDefault,
3823 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003826 "error_if_validators": (
3827 TosaErrorValidator.evWrongInputType,
3828 TosaErrorValidator.evWrongOutputType,
3829 TosaErrorValidator.evWrongInputList,
3830 TosaErrorValidator.evWrongOutputList,
3831 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003832 "data_gen": {
3833 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3834 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003835 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003836 "bitwise_not": {
3837 "op": Op.BITWISE_NOT,
3838 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003839 "build_fcn": (
3840 build_unary,
3841 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003842 TosaTensorValuesGen.tvgLazyGenDefault,
3843 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003845 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003846 "error_if_validators": (
3847 TosaErrorValidator.evWrongInputType,
3848 TosaErrorValidator.evWrongOutputType,
3849 TosaErrorValidator.evWrongInputList,
3850 TosaErrorValidator.evWrongOutputList,
3851 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "ceil": {
3854 "op": Op.CEIL,
3855 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003856 "build_fcn": (
3857 build_unary,
3858 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003859 TosaTensorValuesGen.tvgLazyGenDefault,
3860 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003863 "error_if_validators": (
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
3866 TosaErrorValidator.evWrongInputList,
3867 TosaErrorValidator.evWrongOutputList,
3868 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003869 "data_gen": {
3870 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3871 },
3872 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 "clz": {
3875 "op": Op.CLZ,
3876 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003877 "build_fcn": (
3878 build_unary,
3879 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 TosaTensorValuesGen.tvgLazyGenDefault,
3881 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 "error_if_validators": (
3885 TosaErrorValidator.evWrongInputType,
3886 TosaErrorValidator.evWrongOutputType,
3887 TosaErrorValidator.evWrongInputList,
3888 TosaErrorValidator.evWrongOutputList,
3889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 "exp": {
3892 "op": Op.EXP,
3893 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_unary,
3896 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003897 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003898 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evWrongInputType,
3903 TosaErrorValidator.evWrongOutputType,
3904 TosaErrorValidator.evWrongInputList,
3905 TosaErrorValidator.evWrongOutputList,
3906 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003907 "data_gen": {
3908 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 "floor": {
3912 "op": Op.FLOOR,
3913 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 "build_fcn": (
3915 build_unary,
3916 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003917 TosaTensorValuesGen.tvgLazyGenDefault,
3918 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 "error_if_validators": (
3922 TosaErrorValidator.evWrongInputType,
3923 TosaErrorValidator.evWrongOutputType,
3924 TosaErrorValidator.evWrongInputList,
3925 TosaErrorValidator.evWrongOutputList,
3926 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003927 "data_gen": {
3928 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3929 },
3930 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003932 "log": {
3933 "op": Op.LOG,
3934 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003935 "build_fcn": (
3936 build_unary,
3937 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003938 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003939 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003940 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003942 "error_if_validators": (
3943 TosaErrorValidator.evWrongInputType,
3944 TosaErrorValidator.evWrongOutputType,
3945 TosaErrorValidator.evWrongInputList,
3946 TosaErrorValidator.evWrongOutputList,
3947 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003948 "data_gen": {
3949 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3950 },
3951 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003952 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "logical_not": {
3954 "op": Op.LOGICAL_NOT,
3955 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003956 "build_fcn": (
3957 build_unary,
3958 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003959 TosaTensorValuesGen.tvgLazyGenDefault,
3960 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003961 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003962 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003963 "error_if_validators": (
3964 TosaErrorValidator.evWrongInputType,
3965 TosaErrorValidator.evWrongOutputType,
3966 TosaErrorValidator.evWrongInputList,
3967 TosaErrorValidator.evWrongOutputList,
3968 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 "negate": {
3971 "op": Op.NEGATE,
3972 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_unary,
3975 TosaTensorGen.tgBasic,
3976 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003977 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "qgen": TosaQuantGen.qgUnary,
3980 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 "error_if_validators": (
3982 TosaErrorValidator.evInputZeroPointNotZero,
3983 TosaErrorValidator.evOutputZeroPointNotZero,
3984 TosaErrorValidator.evWrongInputType,
3985 TosaErrorValidator.evWrongOutputType,
3986 TosaErrorValidator.evWrongInputList,
3987 TosaErrorValidator.evWrongOutputList,
3988 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003989 "data_gen": {
3990 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3991 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003992 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 "reciprocal": {
3994 "op": Op.RECIPROCAL,
3995 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003996 "build_fcn": (
3997 build_unary,
3998 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003999 TosaTensorValuesGen.tvgLazyGenDefault,
4000 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004001 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004002 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004003 "error_if_validators": (
4004 TosaErrorValidator.evWrongInputType,
4005 TosaErrorValidator.evWrongOutputType,
4006 TosaErrorValidator.evWrongInputList,
4007 TosaErrorValidator.evWrongOutputList,
4008 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004009 "data_gen": {
4010 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4011 },
4012 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 "rsqrt": {
4015 "op": Op.RSQRT,
4016 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004017 "build_fcn": (
4018 build_unary,
4019 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004020 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004021 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004022 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004023 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004024 "error_if_validators": (
4025 TosaErrorValidator.evWrongInputType,
4026 TosaErrorValidator.evWrongOutputType,
4027 TosaErrorValidator.evWrongInputList,
4028 TosaErrorValidator.evWrongOutputList,
4029 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004030 "data_gen": {
4031 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4032 },
4033 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004034 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 # Elementwise Ternary operators
4036 "select": {
4037 "op": Op.SELECT,
4038 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 "build_fcn": (
4040 build_select,
4041 TosaTensorGen.tgBroadcastFuzz,
4042 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004043 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004044 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004046 "error_if_validators": (
4047 TosaErrorValidator.evRankMismatch,
4048 TosaErrorValidator.evWrongInputType,
4049 TosaErrorValidator.evWrongOutputType,
4050 TosaErrorValidator.evWrongInputList,
4051 TosaErrorValidator.evWrongOutputList,
4052 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004053 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004054 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004055 "data_gen": {
4056 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4057 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 # Comparison operators
4060 "equal": {
4061 "op": Op.EQUAL,
4062 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004063 "build_fcn": (
4064 build_comparison,
4065 TosaTensorGen.tgBroadcastFuzz,
4066 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004067 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004070 "error_if_validators": (
4071 TosaErrorValidator.evRankMismatch,
4072 TosaErrorValidator.evWrongInputType,
4073 TosaErrorValidator.evWrongOutputType,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004077 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004078 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004079 "data_gen": {
4080 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 "greater_equal": {
4084 "op": Op.GREATER_EQUAL,
4085 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_comparison,
4088 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004089 TosaTensorValuesGen.tvgLazyGenDefault,
4090 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evRankMismatch,
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongInputList,
4098 TosaErrorValidator.evWrongOutputList,
4099 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004100 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004101 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004102 "data_gen": {
4103 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004106 "greater": {
4107 "op": Op.GREATER,
4108 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004109 "build_fcn": (
4110 build_comparison,
4111 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004112 TosaTensorValuesGen.tvgLazyGenDefault,
4113 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004114 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004115 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004116 "error_if_validators": (
4117 TosaErrorValidator.evRankMismatch,
4118 TosaErrorValidator.evWrongInputType,
4119 TosaErrorValidator.evWrongOutputType,
4120 TosaErrorValidator.evWrongInputList,
4121 TosaErrorValidator.evWrongOutputList,
4122 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004123 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004124 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004125 "data_gen": {
4126 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004128 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004129 # Reduction operators
4130 "reduce_all": {
4131 "op": Op.REDUCE_ALL,
4132 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004133 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004134 "build_fcn": (
4135 build_reduce,
4136 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004137 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004138 TosaArgGen.agAxis,
4139 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004140 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004141 "error_if_validators": (
4142 TosaErrorValidator.evAxisLargerRank,
4143 TosaErrorValidator.evAxisSmallerZero,
4144 TosaErrorValidator.evShapeOfAxisNotOne,
4145 TosaErrorValidator.evWrongInputType,
4146 TosaErrorValidator.evWrongOutputType,
4147 TosaErrorValidator.evWrongRank,
4148 TosaErrorValidator.evWrongInputList,
4149 TosaErrorValidator.evWrongOutputList,
4150 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004152 "reduce_any": {
4153 "op": Op.REDUCE_ANY,
4154 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004155 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 "build_fcn": (
4157 build_reduce,
4158 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004159 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 TosaArgGen.agAxis,
4161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004162 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004163 "error_if_validators": (
4164 TosaErrorValidator.evAxisLargerRank,
4165 TosaErrorValidator.evAxisSmallerZero,
4166 TosaErrorValidator.evShapeOfAxisNotOne,
4167 TosaErrorValidator.evWrongInputType,
4168 TosaErrorValidator.evWrongOutputType,
4169 TosaErrorValidator.evWrongRank,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004173 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 "reduce_max": {
4175 "op": Op.REDUCE_MAX,
4176 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004177 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004178 "build_fcn": (
4179 build_reduce,
4180 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004181 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004182 TosaArgGen.agAxis,
4183 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004184 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004185 "error_if_validators": (
4186 TosaErrorValidator.evAxisLargerRank,
4187 TosaErrorValidator.evAxisSmallerZero,
4188 TosaErrorValidator.evShapeOfAxisNotOne,
4189 TosaErrorValidator.evWrongInputType,
4190 TosaErrorValidator.evWrongOutputType,
4191 TosaErrorValidator.evWrongRank,
4192 TosaErrorValidator.evWrongInputList,
4193 TosaErrorValidator.evWrongOutputList,
4194 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004195 "data_gen": {
4196 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4197 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004199 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004200 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004202 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 "build_fcn": (
4204 build_reduce,
4205 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004206 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004207 TosaArgGen.agAxis,
4208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004210 "error_if_validators": (
4211 TosaErrorValidator.evAxisLargerRank,
4212 TosaErrorValidator.evAxisSmallerZero,
4213 TosaErrorValidator.evShapeOfAxisNotOne,
4214 TosaErrorValidator.evWrongInputType,
4215 TosaErrorValidator.evWrongOutputType,
4216 TosaErrorValidator.evWrongRank,
4217 TosaErrorValidator.evWrongInputList,
4218 TosaErrorValidator.evWrongOutputList,
4219 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004220 "data_gen": {
4221 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004224 "reduce_product": {
4225 "op": Op.REDUCE_PRODUCT,
4226 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004227 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 "build_fcn": (
4229 build_reduce,
4230 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004231 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004232 TosaArgGen.agAxis,
4233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004234 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004235 "error_if_validators": (
4236 TosaErrorValidator.evAxisLargerRank,
4237 TosaErrorValidator.evAxisSmallerZero,
4238 TosaErrorValidator.evShapeOfAxisNotOne,
4239 TosaErrorValidator.evWrongInputType,
4240 TosaErrorValidator.evWrongOutputType,
4241 TosaErrorValidator.evWrongRank,
4242 TosaErrorValidator.evWrongInputList,
4243 TosaErrorValidator.evWrongOutputList,
4244 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004245 "data_gen": {
4246 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004249 "reduce_sum": {
4250 "op": Op.REDUCE_SUM,
4251 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004252 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004253 "build_fcn": (
4254 build_reduce,
4255 TosaTensorGen.tgBasic,
4256 TosaTensorValuesGen.tvgReduceSum,
4257 TosaArgGen.agAxis,
4258 ),
James Ward24dbc422022-10-19 12:20:31 +01004259 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004260 "error_if_validators": (
4261 TosaErrorValidator.evAxisLargerRank,
4262 TosaErrorValidator.evAxisSmallerZero,
4263 TosaErrorValidator.evShapeOfAxisNotOne,
4264 TosaErrorValidator.evWrongInputType,
4265 TosaErrorValidator.evWrongOutputType,
4266 TosaErrorValidator.evWrongRank,
4267 TosaErrorValidator.evWrongInputList,
4268 TosaErrorValidator.evWrongOutputList,
4269 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004270 "data_gen": {
4271 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4272 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004273 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004274 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004275 "concat": {
4276 "op": Op.CONCAT,
4277 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004278 "build_fcn": (
4279 build_concat,
4280 TosaTensorGen.tgConcat,
4281 TosaTensorValuesGen.tvgConcat,
4282 TosaArgGen.agAxis,
4283 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004284 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004285 "error_if_validators": (
4286 TosaErrorValidator.evAxisLargerRank,
4287 TosaErrorValidator.evAxisSmallerZero,
4288 TosaErrorValidator.evConcatInputRankMismatch,
4289 TosaErrorValidator.evConcatShapeSumMismatch,
4290 TosaErrorValidator.evConcatInputDimMismatch,
4291 TosaErrorValidator.evWrongInputType,
4292 TosaErrorValidator.evWrongOutputType,
4293 TosaErrorValidator.evWrongOutputList,
4294 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004295 "data_gen": {
4296 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4297 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004298 },
4299 "pad": {
4300 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004301 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004302 "build_fcn": (
4303 build_pad,
4304 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004305 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 TosaArgGen.agPad,
4307 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004308 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004309 "error_if_validators": (
4310 TosaErrorValidator.evWrongInputType,
4311 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004312 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004313 TosaErrorValidator.evWrongOutputType,
4314 TosaErrorValidator.evWrongInputList,
4315 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004316 TosaErrorValidator.evRankMismatch,
4317 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004318 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004319 "data_gen": {
4320 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4321 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004322 },
Won Jeona21b2e82023-08-10 10:33:01 +00004323 "dim": {
4324 "op": Op.DIM,
4325 "operands": (1, 0),
4326 "build_fcn": (
4327 build_dim,
4328 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004329 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004330 TosaArgGen.agAxis,
4331 ),
4332 "types": TYPE_FIB,
4333 "error_if_validators": (
4334 TosaErrorValidator.evAxisLargerRank,
4335 TosaErrorValidator.evAxisSmallerZero,
4336 TosaErrorValidator.evWrongInputType,
4337 TosaErrorValidator.evWrongInputList,
4338 TosaErrorValidator.evWrongOutputList,
4339 TosaErrorValidator.evWrongRank,
4340 ),
4341 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 "reshape": {
4343 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004344 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004345 "build_fcn": (
4346 build_reshape,
4347 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004348 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004349 TosaArgGen.agReshape,
4350 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004351 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004352 "error_if_validators": (
4353 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4354 TosaErrorValidator.evWrongInputType,
4355 TosaErrorValidator.evWrongOutputType,
4356 TosaErrorValidator.evWrongInputList,
4357 TosaErrorValidator.evWrongOutputList,
4358 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004359 "data_gen": {
4360 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4361 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004362 },
4363 "reverse": {
4364 "op": Op.REVERSE,
4365 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004366 "build_fcn": (
4367 build_reverse,
4368 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004369 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004370 TosaArgGen.agAxis,
4371 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004372 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004373 "error_if_validators": (
4374 TosaErrorValidator.evAxisSmallerZero,
4375 TosaErrorValidator.evAxisLargerRank,
4376 TosaErrorValidator.evWrongInputType,
4377 TosaErrorValidator.evWrongOutputType,
4378 TosaErrorValidator.evWrongInputList,
4379 TosaErrorValidator.evWrongOutputList,
4380 ),
evacha0198477222024-01-26 12:25:32 +00004381 "data_gen": {
4382 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4383 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004384 },
4385 "slice": {
4386 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004387 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004388 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004389 "build_fcn": (
4390 build_slice,
4391 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004392 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004393 TosaArgGen.agSlice,
4394 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004395 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004396 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004397 # TODO Turn off these error categories for now as the reference
4398 # model cannot allocate memory space for empty tensor. We probably
4399 # can report an accurate error message at the right place during
4400 # execution.
4401 # TosaErrorValidator.evStartSmallerZero,
4402 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004403 TosaErrorValidator.evStartSizeOutsideBounds,
4404 TosaErrorValidator.evSizeOutputShapeMismatch,
4405 TosaErrorValidator.evInputSizeStartLengthMismatch,
4406 TosaErrorValidator.evWrongRank,
4407 TosaErrorValidator.evWrongInputType,
4408 TosaErrorValidator.evWrongOutputType,
4409 TosaErrorValidator.evWrongInputList,
4410 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004411 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004412 ),
evacha017f7d4252024-01-24 12:08:09 +00004413 "data_gen": {
4414 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4415 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004416 },
4417 "tile": {
4418 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004419 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004420 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004421 "build_fcn": (
4422 build_tile,
4423 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004424 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004425 TosaArgGen.agTile,
4426 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004427 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004428 "error_if_validators": (
4429 TosaErrorValidator.evWrongInputType,
4430 TosaErrorValidator.evWrongOutputType,
4431 TosaErrorValidator.evWrongInputList,
4432 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004433 TosaErrorValidator.evRankMismatch,
4434 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004435 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004436 "data_gen": {
4437 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4438 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004439 },
4440 "transpose": {
4441 "op": Op.TRANSPOSE,
4442 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004443 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004444 "build_fcn": (
4445 build_transpose,
4446 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004447 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004448 TosaArgGen.agTranspose,
4449 ),
4450 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004451 "error_if_validators": (
4452 TosaErrorValidator.evIndexOutsideBounds,
4453 TosaErrorValidator.evIndexUsedTwice,
4454 TosaErrorValidator.evWrongInputType,
4455 TosaErrorValidator.evWrongOutputType,
4456 TosaErrorValidator.evWrongInputList,
4457 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004458 TosaErrorValidator.evWrongRank,
4459 TosaErrorValidator.evRankMismatch,
4460 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004461 ),
evacha0198477222024-01-26 12:25:32 +00004462 "data_gen": {
4463 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4464 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004466 # Data nodes
4467 "const": {
4468 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004469 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004470 "build_fcn": (
4471 build_const,
4472 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004473 TosaTensorValuesGen.tvgLazyGenDefault,
4474 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004475 ),
Luke Hutton65872422023-02-20 10:33:04 +00004476 "types": TYPE_FIB + [DType.INT48],
evacha0198477222024-01-26 12:25:32 +00004477 "data_gen": {
4478 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4479 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004480 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004481 "identity": {
4482 "op": Op.IDENTITY,
4483 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004484 "build_fcn": (
4485 build_unary,
4486 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004487 TosaTensorValuesGen.tvgLazyGenDefault,
4488 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004490 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004491 "data_gen": {
4492 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4493 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004494 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004495 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004496 "gather": {
4497 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004498 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004499 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004500 "build_fcn": (
4501 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004502 TosaTensorGen.tgGather,
4503 TosaTensorValuesGen.tvgGather,
4504 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004505 ),
James Ward24dbc422022-10-19 12:20:31 +01004506 "types": (
4507 DType.INT8,
4508 DType.INT16,
4509 DType.INT32,
4510 DType.FP16,
4511 DType.BF16,
4512 DType.FP32,
4513 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004514 "error_if_validators": (
4515 TosaErrorValidator.evWrongInputType,
4516 TosaErrorValidator.evWrongOutputType,
4517 TosaErrorValidator.evWrongInputList,
4518 TosaErrorValidator.evWrongOutputList,
4519 TosaErrorValidator.evWrongRank,
4520 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004521 "data_gen": {
4522 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4523 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004524 },
4525 "scatter": {
4526 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004527 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004528 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004529 "build_fcn": (
4530 build_scatter,
4531 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004532 TosaTensorValuesGen.tvgScatter,
4533 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004534 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004535 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004536 "error_if_validators": (
4537 TosaErrorValidator.evWrongInputType,
4538 TosaErrorValidator.evWrongOutputType,
4539 TosaErrorValidator.evWrongInputList,
4540 TosaErrorValidator.evWrongOutputList,
4541 TosaErrorValidator.evWrongRank,
4542 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004543 "data_gen": {
4544 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4545 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004546 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004547 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004548 "resize": {
4549 "op": Op.RESIZE,
4550 "operands": (1, 0),
4551 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004552 "build_fcn": (
4553 build_resize,
4554 TosaTensorGen.tgNHWC,
4555 TosaTensorValuesGen.tvgDefault,
4556 TosaArgGen.agResize,
4557 ),
James Ward24dbc422022-10-19 12:20:31 +01004558 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004559 "invalid_test_validators": (
4560 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004561 ),
4562 "error_if_validators": (
4563 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004564 TosaErrorValidator.evScaleSmallerEqualZero,
4565 TosaErrorValidator.evScaleNLargerMax,
4566 TosaErrorValidator.evScaleDLargerMax,
4567 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004568 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004569 TosaErrorValidator.evBorderSmallerMin,
4570 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004571 TosaErrorValidator.evWrongInputType,
4572 TosaErrorValidator.evWrongOutputType,
4573 TosaErrorValidator.evWrongRank,
4574 TosaErrorValidator.evWrongInputList,
4575 TosaErrorValidator.evWrongOutputList,
4576 TosaErrorValidator.evBatchMismatch,
4577 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004578 TosaErrorValidator.evResizeOutputShapeMismatch,
4579 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004580 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004581 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004582 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004583 "cast": {
4584 "op": Op.CAST,
4585 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004586 "build_fcn": (
4587 build_cast,
4588 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004589 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004590 TosaArgGen.agCast,
4591 ),
James Ward8b390432022-08-12 20:48:56 +01004592 "types": (
4593 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004594 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004595 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004596 DType.INT8,
4597 DType.INT16,
4598 DType.INT32,
4599 DType.BOOL,
4600 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004601 "error_if_validators": (
4602 TosaErrorValidator.evWrongInputType,
4603 TosaErrorValidator.evWrongOutputType,
4604 TosaErrorValidator.evWrongInputList,
4605 TosaErrorValidator.evWrongOutputList,
4606 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004607 "data_gen": {
4608 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4609 },
4610 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004611 },
4612 "rescale": {
4613 "op": Op.RESCALE,
4614 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004615 "build_fcn": (
4616 build_rescale,
4617 TosaTensorGen.tgBasic,
4618 TosaTensorValuesGen.tvgDefault,
4619 TosaArgGen.agRescale,
4620 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004621 "types": [
4622 DType.UINT8,
4623 DType.INT8,
4624 DType.INT16,
4625 DType.INT32,
4626 DType.INT48,
4627 DType.UINT16,
4628 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004629 "error_if_validators": (
4630 TosaErrorValidator.evInputZeroPointNotZero,
4631 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004632 TosaErrorValidator.evU16InputZeroPointNotValid,
4633 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004634 TosaErrorValidator.evScaleTrue,
4635 TosaErrorValidator.evScaleNotTrue,
4636 TosaErrorValidator.evWrongInputType,
4637 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 TosaErrorValidator.evWrongInputList,
4639 TosaErrorValidator.evWrongOutputList,
4640 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004641 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004642 # Custom
4643 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004644 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004645 # Two variants of cond_if, one that generates one of two constant tensors (no
4646 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4647 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004648 "cond_if_const": {
4649 "op": Op.COND_IF,
4650 "operands": (0, 2),
4651 "build_fcn": (
4652 build_cond_if_const,
4653 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004654 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004655 TosaArgGen.agCondIf,
4656 ),
4657 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004658 "error_if_validators": (
4659 TosaErrorValidator.evOutputListThenGraphMismatch,
4660 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004661 TosaErrorValidator.evCondIfCondNotMatchingBool,
4662 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004663 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004664 },
4665 "cond_if_binary": {
4666 "op": Op.COND_IF,
4667 "operands": (2, 0),
4668 "build_fcn": (
4669 build_cond_if_binary,
4670 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004671 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004672 TosaArgGen.agCondIf,
4673 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004674 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004675 "error_if_validators": (
4676 TosaErrorValidator.evInputListThenGraphMismatch,
4677 TosaErrorValidator.evInputListElseGraphMismatch,
4678 TosaErrorValidator.evOutputListThenGraphMismatch,
4679 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004680 TosaErrorValidator.evCondIfCondNotMatchingBool,
4681 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004682 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004683 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004684 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004685 "while_loop": {
4686 "op": Op.WHILE_LOOP,
4687 "operands": (0, 1),
4688 "build_fcn": (
4689 build_while_loop,
4690 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004691 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004692 TosaArgGen.agWhileLoop,
4693 ),
4694 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004695 "error_if_validators": (
4696 TosaErrorValidator.evInputListOutputListMismatch,
4697 TosaErrorValidator.evInputListCondGraphMismatch,
4698 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4699 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4700 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004701 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004702 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004703 },
Luke Hutton57287132023-02-06 14:54:18 +00004704 "fft2d": {
4705 "op": Op.FFT2D,
4706 "operands": (2, 0),
4707 "rank": (3, 3),
4708 "build_fcn": (
4709 build_fft2d,
4710 TosaTensorGen.tgFFT2d,
4711 TosaTensorValuesGen.tvgDefault,
4712 TosaArgGen.agFFT2d,
4713 ),
4714 "types": [DType.FP32],
4715 "error_if_validators": (
4716 TosaErrorValidator.evWrongInputType,
4717 TosaErrorValidator.evWrongOutputType,
4718 TosaErrorValidator.evWrongInputList,
4719 TosaErrorValidator.evWrongOutputList,
4720 TosaErrorValidator.evWrongRank,
4721 TosaErrorValidator.evBatchMismatch,
4722 TosaErrorValidator.evKernelNotPowerOfTwo,
4723 TosaErrorValidator.evFFTInputShapeMismatch,
4724 TosaErrorValidator.evFFTOutputShapeMismatch,
4725 ),
4726 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004727 "rfft2d": {
4728 "op": Op.RFFT2D,
4729 "operands": (1, 0),
4730 "rank": (3, 3),
4731 "build_fcn": (
4732 build_rfft2d,
4733 TosaTensorGen.tgRFFT2d,
4734 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004735 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004736 ),
4737 "types": [DType.FP32],
4738 "error_if_validators": (
4739 TosaErrorValidator.evWrongInputType,
4740 TosaErrorValidator.evWrongOutputType,
4741 TosaErrorValidator.evWrongInputList,
4742 TosaErrorValidator.evWrongOutputList,
4743 TosaErrorValidator.evWrongRank,
4744 TosaErrorValidator.evBatchMismatch,
4745 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004746 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004747 ),
4748 },
Won Jeon74342e52024-01-09 00:34:40 +00004749 # Shape
4750 "add_shape": {
4751 "op": Op.ADD_SHAPE,
4752 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004753 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004754 "build_fcn": (
4755 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004756 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004757 TosaTensorValuesGen.tvgAddSub,
4758 TosaArgGen.agNone,
4759 ),
4760 "types": [DType.SHAPE],
4761 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4762 },
4763 "sub_shape": {
4764 "op": Op.SUB_SHAPE,
4765 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004766 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004767 "build_fcn": (
4768 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004769 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004770 TosaTensorValuesGen.tvgAddSub,
4771 TosaArgGen.agNone,
4772 ),
4773 "types": [DType.SHAPE],
4774 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4775 },
4776 "mul_shape": {
4777 "op": Op.MUL_SHAPE,
4778 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004779 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004780 "build_fcn": (
4781 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004782 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004783 TosaTensorValuesGen.tvgMul,
4784 TosaArgGen.agNone,
4785 ),
4786 "types": [DType.SHAPE],
4787 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4788 },
4789 "div_shape": {
4790 "op": Op.DIV_SHAPE,
4791 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004792 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004793 "build_fcn": (
4794 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004795 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004796 TosaTensorValuesGen.tvgIntDiv,
4797 TosaArgGen.agNone,
4798 ),
4799 "types": [DType.SHAPE],
4800 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4801 },
4802 "concat_shape": {
4803 "op": Op.CONCAT_SHAPE,
4804 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004805 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004806 "build_fcn": (
4807 build_concat,
4808 TosaTensorGen.tgConcat,
4809 TosaTensorValuesGen.tvgConcat,
4810 TosaArgGen.agNone,
4811 ),
4812 "types": [DType.SHAPE],
4813 "error_if_validators": (),
4814 },
4815 "const_shape": {
4816 "op": Op.CONST_SHAPE,
4817 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004818 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004819 "build_fcn": (
4820 build_const,
4821 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004822 TosaTensorValuesGen.tvgLazyGenDefault,
4823 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004824 ),
4825 "types": [DType.SHAPE],
4826 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004827 }
4828
Kevin Cheng550ccc52021-03-03 11:21:43 -08004829
Eric Kunzee5e26762020-10-13 16:11:07 -07004830class OutputShaper:
4831 # Methods in this class compute the expected output shape and datatype
4832 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; no per-instance state is set up."""
4835
# These methods create the output tensor with ser.addOutput() and return
# the result (shape/dtype are deliberately corrupted for ERROR_IF tests)
4838 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an elementwise binary op with broadcasting.

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used for ERROR_IF fuzzing.
        a: first input tensor (only .shape and .dtype are read).
        b: second input tensor (only .shape and .dtype are read).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # A size-1 dim of `a` takes b's size, but only for valid tests; in
    # ERROR_IF tests the shape of `a` is used unchanged
    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if dim == 1 and broadcasting else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The fuzz index is always drawn, even when it ends up unused
    bad_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bad_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004871
4872 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose inputs must match exactly.

    Both inputs must agree in rank, dtype and every dimension (checked with
    assert); the output takes that shape (as a new list) and dtype.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004883
4884 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an elementwise unary op.

    The output keeps the input's shape. Its dtype is the input's dtype, or a
    randomly chosen different dtype for ErrorIf.WrongOutputType.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004902
4903 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    The output shape follows `cond`; for valid tests a size-1 dim of `cond`
    becomes the largest size of the three inputs on that axis.

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used for ERROR_IF fuzzing.
        cond: condition tensor (only .shape is read).
        a, b: value tensors (only .shape and .dtype are read).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = []
    for idx, dim in enumerate(cond.shape):
        if broadcasting and dim == 1:
            dim = max(dim, a.shape[idx], b.shape[idx])
        out_shape.append(dim)

    # The fuzz index is always drawn (from the rank of `a`), even if unused
    bad_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bad_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004936
4937 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an elementwise comparison op.

    The shape is broadcast from `a` (a size-1 dim takes b's size where b has
    that axis); the dtype is BOOL, or a random non-BOOL dtype for
    ErrorIf.WrongOutputType.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if dim == 1 and idx < b_rank else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The fuzz index is always drawn, even when it ends up unused
    bad_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bad_axis] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so each one is a wrong output dtype
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004970
4971 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction over `axis`.

    Normally the reduced axis becomes size 1. For AxisSmallerZero and
    AxisLargerRank the input shape is kept as-is. For ShapeOfAxisNotOne the
    axis keeps its size, or gets a random size in [2, 10) if it was 1.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004999
5000 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX over `axis`.

    The reduced axis is removed (unless the axis itself is the ERROR_IF
    case) and the dtype is INT32. The rank/shape mismatch ERROR_IF cases
    perturb the shape; WrongOutputType picks a random non-INT32 dtype.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_first = rng.choice([True, False])
        if drop_first and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every remaining dimension by a random amount in [1, 10)
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {DType.INT32}
        out_dtype = rng.choice(list(candidates))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005033
5034 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. padding[0:2] pads the
    H axis and padding[2:4] the W axis; strides[0]/dilations[0] apply to H
    and strides[1]/dilations[1] to W.

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used for ERROR_IF fuzzing.
        ifm: input tensor (only .shape and .dtype are read).
        filter: weight tensor (only .shape is read).
        accum_dtype: output dtype used for valid tests.
        strides, padding, dilations: convolution attributes (see above).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Dilated convolution output size, floor-divided by the stride
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # change: 1 grows h, 2 grows w, 3 grows both. Growth is a multiple of
        # the stride so the non-integer-output error case is not hit as well
        steps = [1, 2, 3]
        change = rng.choice(steps)
        if change != 2:
            h += rng.choice(steps) * strides[0]
        if change != 1:
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably both are
        # acceptable accumulator types)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005085
5086 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Spatial axis i
    (0=D, 1=H, 2=W) uses ifm/filter dim 1 + i, the padding pair
    padding[2 * i], padding[2 * i + 1], dilations[i] and strides[i].

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used for ERROR_IF fuzzing.
        ifm: input tensor (only .shape and .dtype are read).
        filter: weight tensor (only .shape is read).
        accum_dtype: output dtype used for valid tests.
        strides, padding, dilations: convolution attributes (see above).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    d, h, w = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # change 1/2/3 grows only D/H/W, 4 grows all three. Growth is a
        # multiple of the stride so the non-integer-output error is not hit
        steps = [1, 2, 3, 4]
        change = rng.choice(steps)
        dims = [d, h, w]
        for i in range(3):
            if change in (i + 1, 4):
                dims[i] = dims[i] + rng.choice(steps) * strides[i]
        d, h, w = dims

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably both are
        # acceptable accumulator types)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5147
5148 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Spatial axis i
    (0=H, 1=W) uses ifm dim 1 + i, filter dim i, the padding pair
    padding[2 * i], padding[2 * i + 1], dilations[i] and strides[i].

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used for ERROR_IF fuzzing.
        ifm: input tensor (only .shape and .dtype are read).
        filter: weight tensor (only .shape is read).
        accum_dtype: output dtype used for valid tests.
        strides, padding, dilations: convolution attributes (see above).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    h, w = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow h, w or both by a multiple of the stride, so the
        # non-integer-output error case is not hit as well
        steps = [1, 2, 3]
        change = rng.choice(steps)
        if change in (1, 3):
            h = h + rng.choice(steps) * strides[0]
        if change in (2, 3):
            w = w + rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably both are
        # acceptable accumulator types)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005198
5199 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling op on an NHWC input.

    pad[0:2] pads the H axis and pad[2:4] the W axis; kernel and stride are
    indexed [H, W]. The output keeps the input's batch and channel dims.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # An incorrect stride/pad makes the test invalid anyway; use 1x1
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow h, w or both by a multiple of the stride, so the
        # non-integer-output error case is not hit as well
        steps = [1, 2, 3]
        change = rng.choice(steps)
        if change in (1, 3):
            h += rng.choice(steps) * stride[0]
        if change in (2, 3):
            w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {ifm.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005236
5237 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: input is [N, IC], filter is [OC, IC] and the output is [N, OC].
    `rng` and `error_name` are not used here.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    batch, out_channels = input.shape[0], filter.shape[0]
    # The dtype is validated in arg_gen (and also invalidated there for
    # ErrorIf tests), so it is used as given
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005249
5250 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: a is [N, H, C], b is [N, C, W] and the output is [N, H, W].

    Args:
        ser: serializer; the output is created with ser.addOutput().
        rng: random generator used to pick a wrong dtype for ERROR_IF tests.
        a: first input tensor (only .shape and .dtype are read).
        b: second input tensor (only .shape is read).
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: optional ErrorIf case the output should provoke.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Input dtype not covered above: pick any usable dtype other than
            # the expected accumulator dtype instead of failing with an
            # UnboundLocalError on `incorrect_types`
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005297
5298 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT of `inputs` along `axis`.

    The output starts from the first input's shape. The other inputs' sizes
    on `axis` are added only when ranks match and the axis is valid (i.e.
    not one of those ERROR_IF cases). ConcatShapeSumMismatch then adds a
    random 5..9 to that axis.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        } - {first.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005333
5334 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Output dim i is padding[i][0] + padding[i][1] + a.shape[i]. The ERROR_IF
    cases grow one random dim by 1 or 2 (PadOutputShapeMismatch) or change
    the rank via gtu.get_rank_mismatch_shape (RankMismatch).

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        pad_before, pad_after = padding[idx][0], padding[idx][1]
        out_shape[idx] = pad_before + pad_after + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005364
5365 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the single-element output tensor for DIM.

    The output has shape [1] and dtype DType.SHAPE; `a` and `axis` are not
    read. For ErrorIf.WrongOutputType a random non-SHAPE dtype is used.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    # None of these is DType.SHAPE, so any of them is a wrong output dtype
    candidates = list(
        set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    )
    return ser.addOutput([1], rng.choice(candidates))
5385
5386 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE to `shape`.

    For TensorSizeInputOutputMismatch every dim of `shape` is grown by a
    random amount in [1, 10); for WrongOutputType a random dtype other than
    the input's is used.

    Returns:
        Whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        ) - {a.dtype}
        out_dtype = rng.choice(list(candidates))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005410
5411 @staticmethod
Luke Huttona4e48ca2023-02-22 11:53:48 +00005412 def sliceOp(ser, rng, input, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005413
Matthew Haddone807aae2021-10-11 18:12:58 +01005414 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005415 all_dtypes = [
5416 DType.INT8,
5417 DType.INT16,
5418 DType.INT32,
5419 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005420 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005421 DType.FP16,
5422 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005423 ]
Luke Huttona4e48ca2023-02-22 11:53:48 +00005424 wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
Matthew Haddone807aae2021-10-11 18:12:58 +01005425 outputDType = rng.choice(wrong_dtypes)
5426 else:
Luke Huttona4e48ca2023-02-22 11:53:48 +00005427 outputDType = input.dtype
Matthew Haddone807aae2021-10-11 18:12:58 +01005428
Luke Huttona4e48ca2023-02-22 11:53:48 +00005429 output_shape = size.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005430 if error_name == ErrorIf.SizeOutputShapeMismatch:
Matthew Haddone807aae2021-10-11 18:12:58 +01005431 for index in range(len(output_shape)):
5432 if output_shape[index] <= 2:
5433 output_shape[index] = output_shape[index] + rng.choice([1, 2])
5434 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005435 output_shape[index] = output_shape[index] + rng.choice(
5436 [-2, -1, 1, 2]
5437 )
Luke Huttona4e48ca2023-02-22 11:53:48 +00005438 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
5439 output_shape = input.shape.copy()
5440 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005441 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Matthew Haddone807aae2021-10-11 18:12:58 +01005442
5443 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005444
5445 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005446 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005447
5448 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005449 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005450
5451 for i in range(len(output_shape)):
5452 output_shape[i] = a.shape[i] * multiples[i]
5453
Luke Huttona4e48ca2023-02-22 11:53:48 +00005454 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005455 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005456
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005457 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005458 all_dtypes = [
5459 DType.INT8,
5460 DType.INT16,
5461 DType.INT32,
5462 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005463 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005464 DType.FP16,
5465 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005466 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005467 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5468 outputDType = rng.choice(wrong_dtypes)
5469 else:
5470 outputDType = a.dtype
5471
5472 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005473
5474 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005475 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005476 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005477
Kevin Cheng550ccc52021-03-03 11:21:43 -08005478 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005479
Luke Huttona4e48ca2023-02-22 11:53:48 +00005480 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01005481 for i in range(len(output_shape)):
5482 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005483
Luke Huttona4e48ca2023-02-22 11:53:48 +00005484 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5485 for i in range(len(output_shape)):
5486 output_shape[i] += rng.integers(1, 10)
5487 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005488 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005489
Matthew Haddone807aae2021-10-11 18:12:58 +01005490 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005491 all_dtypes = [
5492 DType.INT8,
5493 DType.INT16,
5494 DType.INT32,
5495 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005496 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005497 DType.FP16,
5498 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005499 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005500 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5501 outputDType = rng.choice(wrong_dtypes)
5502 else:
5503 outputDType = a.dtype
5504
5505 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005506
5507 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005508 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005509 if error_name != ErrorIf.WrongRank:
5510 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005511 assert len(indices.shape) == 2
5512 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005513
Kevin Cheng77d0f762020-11-24 10:26:32 -08005514 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5515
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005516 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005517 all_dtypes = [
5518 DType.INT8,
5519 DType.INT16,
5520 DType.INT32,
5521 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005522 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005523 DType.FP16,
5524 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005525 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005526 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
5527 outputDType = rng.choice(wrong_dtypes)
5528 else:
5529 outputDType = values.dtype
5530
5531 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005532
5533 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005534 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005535 if error_name != ErrorIf.WrongRank:
5536 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005537 assert len(indices.shape) == 2
5538 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005539 assert values_in.shape[0] == indices.shape[0] # N
5540 assert input.shape[1] == indices.shape[1] # W
5541 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005542
5543 output_shape = values_in.shape
5544
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005545 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005546 all_dtypes = [
5547 DType.INT8,
5548 DType.INT16,
5549 DType.INT32,
5550 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005551 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005552 DType.FP16,
5553 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005554 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005555 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
5556 outputDType = rng.choice(wrong_dtypes)
5557 else:
5558 outputDType = values_in.dtype
5559
5560 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005561
5562 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005563 def tableOp(ser, rng, input, error_name=None):
5564 # Same shape as the input, dtype dependent on input dtype
5565 if error_name != ErrorIf.WrongInputType:
5566 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00005567 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005568 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005569 wrong_dtypes = [
5570 DType.INT8,
5571 DType.INT16,
5572 DType.INT32,
5573 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005574 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005575 DType.FP16,
5576 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005577 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005578 wrong_dtypes.remove(output_dtype)
5579 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005580 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005581
5582 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005583 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005584 serializer,
5585 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005586 input,
5587 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005588 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005589 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005590 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005591 input_dtype,
5592 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005593 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005594 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005595 # Calculate OH, OW
5596 scale_y_n = scale[0]
5597 scale_y_d = scale[1]
5598 scale_x_n = scale[2]
5599 scale_x_d = scale[3]
5600 if error_name == ErrorIf.ScaleSmallerEqualZero:
5601 scale_y_n = max(scale_y_n, 1)
5602 scale_y_d = max(scale_y_d, 1)
5603 scale_x_n = max(scale_x_n, 1)
5604 scale_x_d = max(scale_x_d, 1)
5605
5606 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
5607 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
5608
5609 if error_name is not None:
5610 # Make sure the output tensor is valid, which can occur when
5611 # scale, offset or border have been changed for ERROR_IFs
5612 oh = max(oh, 1)
5613 ow = max(ow, 1)
5614 if error_name != ErrorIf.MaxDimExceeded:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005615 oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
5616 ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005617
5618 if error_name == ErrorIf.ResizeOutputShapeMismatch:
5619 choices = [1, 2, 3]
5620 change = rng.choice(choices)
5621 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
5622 if change in [1, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005623 if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005624 oh -= scale_y_d
5625 assert oh > 0 # Should have been caught in agResize
5626 else:
5627 oh += scale_y_d
5628 if change in [2, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005629 if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005630 ow -= scale_x_d
5631 assert ow > 0 # Should have been caught in agResize
5632 else:
5633 ow += scale_x_d
5634
Matthew Haddon848efb42021-09-09 12:30:53 +01005635 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005636 output_dims = [
5637 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005638 oh,
5639 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005640 input.shape[0],
5641 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005642 elif error_name == ErrorIf.BatchMismatch:
5643 output_dims = [
5644 input.shape[0] + rng.integers(1, 10),
5645 oh,
5646 ow,
5647 input.shape[3],
5648 ]
5649 elif error_name == ErrorIf.ChannelMismatch:
5650 output_dims = [
5651 input.shape[0],
5652 oh,
5653 ow,
5654 input.shape[3] + rng.integers(1, 10),
5655 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01005656 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005657 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005658
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005659 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005660
5661 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005662 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005663 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005664
5665 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005666 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005667 if error_name == ErrorIf.ConvOutputShapeMismatch:
5668 choices = [1, 2, 3]
5669 change = rng.choice(choices)
5670 if change in [1, 3]:
5671 output_shape[1] = output_shape[1] + rng.choice(choices)
5672 if change in [2, 3]:
5673 output_shape[2] = output_shape[2] + rng.choice(choices)
5674
James Ward8b390432022-08-12 20:48:56 +01005675 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005676 # Pick some potentially correct output dtype if input type is incorrect
5677 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005678 else:
James Ward8b390432022-08-12 20:48:56 +01005679 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005680
5681 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005682 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005683 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005684 else:
5685 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005686 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005687 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005688
Kevin Cheng550ccc52021-03-03 11:21:43 -08005689 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005690
5691 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00005692 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
5693 outputs = []
5694
5695 assert ifm1.dtype == ifm2.dtype
5696 input_dtype = ifm1.dtype
5697
5698 if error_name != ErrorIf.FFTInputShapeMismatch:
5699 assert ifm1.shape == ifm2.shape
5700
5701 input_shape = ifm1.shape
5702 if error_name != ErrorIf.WrongRank:
5703 assert len(input_shape) == 3
5704
5705 output_shape = input_shape.copy()
5706 output_dtype = input_dtype
5707
5708 if error_name == ErrorIf.WrongOutputType:
5709 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005710 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton57287132023-02-06 14:54:18 +00005711 output_dtype = rng.choice(wrong_dtypes)
5712 elif error_name == ErrorIf.BatchMismatch:
5713 output_shape[0] += rng.integers(1, 10)
5714 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5715 modify_dim = rng.choice([1, 2])
5716 output_shape[modify_dim] += rng.integers(1, 10)
5717
5718 outputs.append(serializer.addOutput(output_shape, output_dtype))
5719 outputs.append(serializer.addOutput(output_shape, output_dtype))
5720 return outputs
5721
5722 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00005723 def rfft2dOp(serializer, rng, value, error_name=None):
5724 outputs = []
5725
5726 input_shape = value.shape
5727 if error_name != ErrorIf.WrongRank:
5728 assert len(input_shape) == 3
5729
5730 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
5731
5732 output_dtype = value.dtype
5733 if error_name == ErrorIf.WrongOutputType:
5734 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005735 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton261b7b62023-01-10 14:50:31 +00005736 output_dtype = rng.choice(wrong_dtypes)
5737 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00005738 output_shape[0] += rng.integers(1, 10)
5739 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5740 modify_dim = rng.choice([1, 2])
5741 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00005742
5743 outputs.append(serializer.addOutput(output_shape, output_dtype))
5744 outputs.append(serializer.addOutput(output_shape, output_dtype))
5745 return outputs
Won Jeon74342e52024-01-09 00:34:40 +00005746
5747 @staticmethod
5748 def addShapeOp(ser, rng, a, b, error_name=None):
5749 if error_name != ErrorIf.RankMismatch:
5750 assert len(a.shape) == len(b.shape)
5751 assert a.dtype == b.dtype
5752
5753 shape = []
5754 for i in range(len(a.shape)):
5755 shape.append(a.shape[i])
5756
5757 fuzz_idx = rng.integers(0, len(a.shape))
5758 if error_name == ErrorIf.DimensionMismatch:
5759 shape[fuzz_idx] += 1
5760
5761 if error_name == ErrorIf.WrongOutputType:
5762 wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
5763 outputDType = rng.choice(wrong_dtypes)
5764 else:
5765 outputDType = DType.SHAPE
5766 return ser.addOutput(shape, outputDType)