blob: 6867979ab3cb505a2f0c58d5b14d4eb62814b75b [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of the generated C++ meta data files; the
# copyright year is the year in which the tests are generated.
TOSA_AUTOGENERATED_HEADER = "".join(
    (
        f"// Copyright (c) {datetime.today().year}, ARM Limited\n",
        "// SPDX-License-Identifier: Apache-2.0\n",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n",
    )
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
79 for dtype in (DType.FP32, DType.FP16, DType.BF16):
80 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000198 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 return np.int64(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
201 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
202
203 if dtype == DType.FP16:
204 return np.float16(f_tensor)
205 else:
206 f32_tensor = np.float32(f_tensor)
207 if dtype == DType.BF16:
208 # Floor the last 16 bits of each f32 value
209 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
210 else:
211 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100213 # All other integer types
214 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
Kevin Cheng989cb052021-04-28 16:29:44 -0700216 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 placeholders = []
218
Kevin Cheng989cb052021-04-28 16:29:44 -0700219 assert len(shape_list) == len(dtype_list)
220
Jeremy Johnson1271c442023-09-05 11:39:26 +0100221 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700222 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 if not self.args.lazy_data_gen:
224 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700225 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700226
227 return placeholders
228
Kevin Cheng989cb052021-04-28 16:29:44 -0700229 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 consts = []
231
Kevin Cheng989cb052021-04-28 16:29:44 -0700232 assert len(shape_list) == len(dtype_list)
233
Jeremy Johnson1271c442023-09-05 11:39:26 +0100234 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700235 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100236 if not self.args.lazy_data_gen:
237 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700238 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700239
240 return consts
241
242 def makeShape(self, rank):
243 if self.targetted_shape:
244 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800245 return np.int32(
246 self.rng.integers(
247 low=self.args.tensor_shape_range[0],
248 high=self.args.tensor_shape_range[1],
249 size=rank,
250 )
251 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
253 def setTargetShape(self, shape):
254 self.targetted_shape = shape
255
256 def randInt(self, low=0, high=256):
257 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
258
259 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 low, high = self.getDTypeRange(dtype)
261
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100264 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100265 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100266 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100267 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
268 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 elif dtype == DType.BOOL:
270 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000271 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 # Special size
273 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
275 return np.int32(self.rng.integers(low, high, size=1))[0]
276
277 def shapeStr(self, shape):
278
279 sStr = []
280 # Convert to strings
281 for i in shape:
282 sStr.append(str(i))
283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeStr(self, dtype):
287 if isinstance(dtype, list) or isinstance(dtype, tuple):
288 assert len(dtype) >= 2
289 strs = [self.typeStr(t) for t in dtype]
290 # Limit types to the first 2 as the 3rd is the accumulator
291 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100293 if dtype in gtu.DTYPE_ATTRIBUTES:
294 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700295 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100296 raise Exception(
297 "Unknown dtype, cannot convert to string: {}".format(dtype)
298 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100300 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100301 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100302 if dtype in gtu.DTYPE_ATTRIBUTES:
303 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100305 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
Luke Hutton57287132023-02-06 14:54:18 +0000307 def constrictBatchSize(self, shape):
308 # Limit the batch size unless an explicit target shape set
309 if self.args.max_batch_size and not self.args.target_shapes:
310 shape[0] = min(shape[0], self.args.max_batch_size)
311 return shape
312
James Ward30124a82023-02-02 14:56:33 +0000313 def makeDimension(self):
314 return self.randInt(
315 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
316 )
317
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100318 def tensorComplianceMetaData(
319 self, op, inputType, argsDict, outputTensor, errorName
320 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000321 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
322 UNSUPPORTED_NON_FP32_INPUT_OPS = (
323 Op.MATMUL,
324 Op.CONV2D,
325 Op.FULLY_CONNECTED,
326 Op.DEPTHWISE_CONV2D,
327 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100328 if (
329 errorName
330 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000331 or (
332 not gtu.dtypeIsSupportedByCompliance(inputType)
333 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
334 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100335 ):
336 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100337 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100338
Jeremy Johnson1271c442023-09-05 11:39:26 +0100339 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100340 compliance_tens = {
341 "mode": None,
342 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
343 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
344 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100345 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
346 mode = gtu.ComplianceMode.DOT_PRODUCT
347 compliance_tens["dot_product_info"] = {
348 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100349 "ks": int(argsDict["ksb"])
350 if "ksb" in argsDict
351 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100352 }
353 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
354 mode = gtu.ComplianceMode.FP_SPECIAL
355 elif "compliance" in op and "ulp" in op["compliance"]:
356 mode = gtu.ComplianceMode.ULP
357 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
358 elif op["op"] == Op.REDUCE_PRODUCT:
359 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000360 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000361 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000362 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000363 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
364 compliance_tens["abs_error_info"] = {
365 "lower_bound": op["compliance"]["abs_error_lower_bound"]
366 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100367 else:
368 mode = gtu.ComplianceMode.EXACT
369 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
370
371 return compliance_tens
372
373 # Build Op functions
374 # Create the output tensor (calling OutputShaper as needed)
375 # Do final tweaks to attributes (if necessary for errorIf)
376 # Add Op into graph
377 # Return resulting tensor information or BuildInfo
378
379 class BuildInfo:
380 """Enhanced build information containing result tensor and associated compliance dict."""
381
382 def __init__(self, resultTensor, complianceDict):
383 self.resultTensor = resultTensor
384 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700385
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000386 def build_unary(
387 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
388 ):
389 assert len(inputs) == 1
390 a = inputs[0]
391 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100392
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000393 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100394
395 # Ensure new output type has correct qinfo
396 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000397 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000398 qinfo = [
399 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000400 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000401 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100402
403 # Invalidate Input/Output list for error if checks.
404 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000405 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100406 pCount, cCount = op["operands"]
407 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000408 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
409 self, error_name, input_list, output_list
410 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100411
Les Bell729b0352021-11-24 10:28:21 +0000412 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100413 self.ser,
414 validator_fcns,
415 error_name,
416 op=op,
417 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000418 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000419 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000420 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100421 input_list=input_list,
422 output_list=output_list,
423 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000424 ):
425 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100426
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000427 attr = None
428 if op["op"] == Op.NEGATE:
429 attr = ts.TosaSerializerAttribute()
430 attr.NegateAttribute(qinfo[0], qinfo[1])
431
432 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000433
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000434 compliance = self.tensorComplianceMetaData(
435 op, a.dtype, args_dict, result_tensor, error_name
436 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000437 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700438
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000439 def build_binary_broadcast(
440 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
441 ):
442 assert len(inputs) == 2
443 a, b = inputs
444 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000445 self.ser, self.rng, a, b, error_name
446 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100447
448 # Invalidate Input/Output list for error if checks.
449 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000450 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100451 pCount, cCount = op["operands"]
452 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000453 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
454 self, error_name, input_list, output_list
455 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100456
Les Bell729b0352021-11-24 10:28:21 +0000457 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100458 self.ser,
459 validator_fcns,
460 error_name,
461 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000462 input1=a,
463 input2=b,
464 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000465 output_dtype=result_tensor.dtype,
466 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100467 input_list=input_list,
468 output_list=output_list,
469 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000470 ):
471 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100472
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000474
Jeremy Johnson9a758382023-11-07 16:27:35 +0000475 compliance = self.tensorComplianceMetaData(
476 op, a.dtype, args_dict, result_tensor, error_name
477 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000478
479 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700480
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100481 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700482 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000483 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700484 return result_tens
485
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000486 def build_arithmetic_right_shift(
487 self, op, a, b, round, validator_fcns=None, error_name=None
488 ):
489 result_tens = OutputShaper.binaryBroadcastOp(
490 self.ser, self.rng, a, b, error_name
491 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100492
493 # Invalidate Input/Output list for error if checks.
494 input_list = [a.name, b.name]
495 output_list = [result_tens.name]
496 pCount, cCount = op["operands"]
497 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000498 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
499 self, error_name, input_list, output_list
500 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100501
Les Bell729b0352021-11-24 10:28:21 +0000502 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100503 self.ser,
504 validator_fcns,
505 error_name,
506 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000507 input1=a,
508 input2=b,
509 input_dtype=a.dtype,
510 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000511 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100512 input_list=input_list,
513 output_list=output_list,
514 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000515 ):
516 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800517
518 attr = ts.TosaSerializerAttribute()
519 attr.ArithmeticRightShiftAttribute(round)
520
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000521 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800522 return result_tens
523
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100524 def build_mul(
525 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
526 ):
527 assert len(inputs) == 2
528 a, b = inputs
529 shift = args_dict["shift"]
530
531 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 self.ser, self.rng, a, b, error_name
533 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700534
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100535 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100536 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100537 result_tensor.setDtype(DType.INT32)
538
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100539 if error_name == ErrorIf.WrongOutputType:
540 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
541 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100542 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100543
544 # Invalidate Input/Output list for error if checks.
545 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100546 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100547 pCount, cCount = op["operands"]
548 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000549 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
550 self, error_name, input_list, output_list
551 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100552
Les Bell729b0352021-11-24 10:28:21 +0000553 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 self.ser,
555 validator_fcns,
556 error_name,
557 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000558 input1=a,
559 input2=b,
560 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100561 output_dtype=result_tensor.dtype,
562 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100563 input_list=input_list,
564 output_list=output_list,
565 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000566 ):
567 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700568
Kevin Chengaee1fac2020-11-11 13:54:06 -0800569 attr = ts.TosaSerializerAttribute()
570 attr.MulAttribute(shift)
571
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000572 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100573
574 compliance = self.tensorComplianceMetaData(
575 op, a.dtype, args_dict, result_tensor, error_name
576 )
577
578 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700579
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100580 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
581 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700582
Kevin Chengfe392ce2021-10-18 21:51:55 +0000583 attr = ts.TosaSerializerAttribute()
584 attr.TableAttribute(table)
585
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100586 # Invalidate Input/Output list for error if checks.
587 input_list = [a.name]
588 output_list = [result_tens.name]
589 pCount, cCount = op["operands"]
590 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000591 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
592 self, error_name, input_list, output_list
593 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100594
Les Bell729b0352021-11-24 10:28:21 +0000595 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100596 self.ser,
597 validator_fcns,
598 error_name,
599 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000600 input_shape=a.shape,
601 input_dtype=a.dtype,
602 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000603 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100604 input_list=input_list,
605 output_list=output_list,
606 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000607 ):
608 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100609
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000610 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700611
612 return result_tens
613
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000614 def build_select(
615 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
616 ):
617 assert len(inputs) == 3
618 cond, a, b = inputs
619
620 result_tensor = OutputShaper.selectOp(
621 self.ser, self.rng, cond, a, b, error_name
622 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100623
624 # Invalidate Input/Output list for error if checks.
625 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000626 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100627 pCount, cCount = op["operands"]
628 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000629 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
630 self, error_name, input_list, output_list
631 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100632
Les Bell729b0352021-11-24 10:28:21 +0000633 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100634 self.ser,
635 validator_fcns,
636 error_name,
637 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000638 input1=cond,
639 input2=a,
640 input3=b,
641 input_shape=a.shape,
642 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000643 output_dtype=result_tensor.dtype,
644 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100645 input_list=input_list,
646 output_list=output_list,
647 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000648 ):
649 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100650
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000651 self.ser.addOperator(
652 op["op"],
653 input_list,
654 output_list,
655 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000656 compliance = self.tensorComplianceMetaData(
657 op, a.dtype, args_dict, result_tensor, error_name
658 )
659
660 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700661
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input comparison operator.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly two tensors: (left operand, right operand).
        args_dict: Test arguments, forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Not used by this operator.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    placeholder_count, const_count = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700709
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator over the axis given in ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` default to None, consistent with the
    other ``build_*`` methods; existing positional callers are unaffected.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Not used by this operator.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700753
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator (shared by the pooling variants).

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "stride", "pad" and "kernel" are required,
            "acc_type" is optional (max pooling has none). The dict is also
            forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the PoolAttribute; None is
            treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # Max pooling has no accumulator type, so "acc_type" may be absent.
    acc_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure the zero points match the new (deliberately wrong) types.
    regen_zero_points = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if regen_zero_points:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tens.dtype),
        ]

    placeholder_count, const_count = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], acc_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
James Ward8b390432022-08-12 20:48:56 +0100827
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and return it with compliance info.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: (input, weights, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" (4 values) and
            "dilation" are read, and the dict is forwarded to the compliance
            meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the ConvAttribute. None (the
            parameter default) is treated as [0, 0] instead of failing
            with a TypeError on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700909
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: (input, weights, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" (6 values) and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the ConvAttribute. None (the
            parameter default) is treated as [0, 0] instead of failing
            with a TypeError on ``qinfo[0]``.

    Returns:
        The bare result tensor (unlike CONV2D, no compliance meta data is
        attached here), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
986
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        ifm: Input tensor.
        filter: Weight tensor (name kept for caller compatibility).
        bias: Bias tensor.
        accum_dtype: Accumulator type passed to the output shaper.
        stride: Stride values.
        out_pad: Four output padding values.
        output_shape: Requested output shape, used for the result tensor and
            the TransposeConvAttribute.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the TransposeConvAttribute. None
            (the parameter default) is treated as [0, 0] instead of
            failing with a TypeError on ``qinfo[0]``.

    Returns:
        The bare result tensor, or None when ERROR_IF validation rejects the
        test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1054
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and return it with compliance info.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: (input, weights, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" and "dilation"
            are read, and the dict is forwarded to the compliance meta data
            helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the ConvAttribute. None (the
            parameter default) is treated as [0, 0] instead of failing
            with a TypeError on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and return it with compliance info.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: (input, weights, bias).
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the FullyConnectedAttribute. None
            (the parameter default) is treated as [0, 0] instead of
            failing with a TypeError on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001191
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator and return it with compliance info.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly two tensors: (a, b).
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Two zero points written to the MatMulAttribute. None (the
            parameter default) is treated as [0, 0] instead of failing
            with a TypeError on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Without this fallback the default qinfo=None would fail on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001241
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator over the axis given in ``args_dict["axis"]``.

    ``validator_fcns`` defaults to None, consistent with the other
    ``build_*`` methods; existing positional callers are unaffected.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read. For a valid REDUCE_PRODUCT
            test this dict is updated in place with "n" (the size of the
            reduced axis) before being forwarded to the compliance helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Not used by this operator.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001290
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly chosen min/max bounds.

    Two random values of the input dtype are drawn. Normally the larger one
    becomes max and the smaller one min. For the MaxSmallerMin ERROR_IF test
    the values are redrawn until they differ and then swapped, so that
    max < min.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Not used by this operator.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance info,
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]

    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    first = self.getRandNumberDType(a.dtype)
    second = self.getRandNumberDType(a.dtype)

    if error_name != ErrorIf.MaxSmallerMin:
        max_val, min_val = max(first, second), min(first, second)
    else:
        # The bounds must differ so that swapping them really gives max < min.
        while first == second:
            first = self.getRandNumberDType(a.dtype)
            second = self.getRandNumberDType(a.dtype)
        max_val, min_val = min(first, second), max(first, second)

    placeholder_count, const_count = op["operands"]
    # ERROR_IF tests may have their input/output name lists invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Integer types: bounds go in the integer slots, float slots are 0.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float types: integer slots are 0, bounds go in the float slots.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001356
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a leaky-ReLU operator on tensor ``a``.

    The single attribute value is drawn with
    ``self.getRandNumberDType(DType.FP32)``. No ERROR_IF validation is
    performed here; ``validator_fcns`` is accepted but unused.

    Returns:
        The output tensor produced by ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    attr_value = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1365
# TODO: needs an additional type/input; for now it is built like a unary op.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PReLU-style operator that currently consumes only tensor ``a``.

    No attribute is attached and no ERROR_IF validation is performed;
    ``validator_fcns`` is accepted but unused.

    Returns:
        The output tensor produced by ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1372
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator with no attribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=p_count + c_count,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001413
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator from all tensors in ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it as an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and carries no attribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    # CONCAT_SHAPE has no axis argument; it always joins along axis 0.
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        # CONCAT_SHAPE is serialized without an attribute.
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001472
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator for the single tensor in ``inputs``.

    ``args_dict`` supplies the padding array (``"pad"``) and the integer and
    floating-point pad constants (``"pad_const_int"``, ``"pad_const_fp"``).
    These are serialized into a PadAttribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_array = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    padded = OutputShaper.padOp(self.ser, self.rng, src, pad_array, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    # The attribute receives the padding as a flattened sequence.
    pad_attr.PadAttribute(self.ser.builder, pad_array.flatten(), const_int, const_fp)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [padded.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=src.shape,
        output_shape=padded.shape,
        input_dtype=src.dtype,
        output_dtype=padded.dtype,
        pad=pad_array,
        qinfo=qinfo,
        result_tensors=[padded],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, padded, error_name
    )
    return TosaTestGen.BuildInfo(padded, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001530
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator on the single tensor in ``inputs``.

    ``args_dict["axis"]`` is passed to the output shaper and serialized as an
    AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None), since no compliance
        metadata is generated for this op, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    dim_out = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [dim_out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=dim_out.shape,
        output_dtype=dim_out.dtype,
        result_tensors=[dim_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return TosaTestGen.BuildInfo(dim_out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001576
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator from a data tensor and a shape tensor.

    ``inputs`` is ``[data, shape]``. The output tensor is shaped from
    ``args_dict["new_shape"]``. No attribute is attached, because the target
    shape is passed as the second operand.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    data_tensor = inputs[0]
    shape_tensor = inputs[1]
    new_shape = args_dict["new_shape"]
    reshaped = OutputShaper.reshapeOp(
        self.ser, self.rng, data_tensor, new_shape, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data_tensor.name, shape_tensor.name], [reshaped.name]
    )

    error_if_kwargs = {
        "op": op,
        "input_shape": data_tensor.shape,
        "output_shape": reshaped.shape,
        "input_dtype": data_tensor.dtype,
        "output_dtype": reshaped.dtype,
        "result_tensors": [reshaped],
        "input_list": names_in,
        "output_list": names_out,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, data_tensor.dtype, args_dict, reshaped, error_name
    )
    return TosaTestGen.BuildInfo(reshaped, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator on one tensor along ``args_dict["axis"]``.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None), since no compliance
        metadata is generated here, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    reversed_out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [reversed_out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=reversed_out.shape,
        input_dtype=src.dtype,
        output_dtype=reversed_out.dtype,
        result_tensors=[reversed_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return TosaTestGen.BuildInfo(reversed_out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001660
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on ``a``.

    The permutation ``perms`` is used to shape the output and is serialized
    as a TransposeAttribute.

    Returns:
        The output tensor, or None when TosaErrorValidator.evValidateErrorIfs
        returns a falsy value.
    """
    transposed = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [transposed.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=transposed.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=transposed.dtype,
        result_tensors=[transposed],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], names_in, names_out, perm_attr)
    return transposed
1696
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator on the single tensor in ``inputs``.

    ``args_dict["start"]`` and ``args_dict["size"]`` shape the output and are
    serialized as a SliceAttribute once validation succeeds.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    start, size = args_dict["start"], args_dict["size"]

    sliced = OutputShaper.sliceOp(self.ser, self.rng, src, start, size, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [sliced.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=src.shape,
        output_shape=sliced.shape,
        input_dtype=src.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], names_in, names_out, slice_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, sliced, error_name
    )
    return TosaTestGen.BuildInfo(sliced, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001747
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator from a data tensor and a multiples tensor.

    ``inputs`` is ``[data, multiples]``. The output tensor is shaped from
    ``args_dict["multiples"]``. No attribute is attached, because the
    multiples are passed as the second operand.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    src = inputs[0]
    multiples_tensor = inputs[1]
    multiples_values = args_dict["multiples"]
    tiled = OutputShaper.tileOp(
        self.ser, self.rng, src, multiples_values, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [tiled.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=tiled.shape,
        input_dtype=src.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, tiled, error_name
    )
    return TosaTestGen.BuildInfo(tiled, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001792
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from ``[values, indices]``.

    The values tensor supplies the input shape and dtype for the ERROR_IF
    checks, and its dtype is used for the compliance metadata.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    values_tensor, index_tensor = inputs

    gathered = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, index_tensor, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_tensor.name, index_tensor.name], [gathered.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=values_tensor.shape,
        output_shape=gathered.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, gathered, error_name
    )
    return TosaTestGen.BuildInfo(gathered, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001835
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from ``[values_in, indices, input]``.

    Shape and dtype ERROR_IF checks use ``values_in``. Its dtype is also used
    for the compliance metadata.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named ``input_tensor`` rather than ``input`` to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001877
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on tensor ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are passed to the output
    shaper and the ERROR_IF checks. They are then serialized as
    ``ResizeAttribute(scale, offset, border, mode)``. The parameter name
    ``input`` shadows the builtin but is kept for keyword-call compatibility.

    Returns:
        The output tensor, or None when TosaErrorValidator.evValidateErrorIfs
        returns a falsy value.
    """
    src = input  # local alias so the body does not use the builtin's name
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [resized.name]
    )

    error_if_kwargs = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=names_in,
        output_list=names_out,
        result_tensors=[resized],
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], names_in, names_out, resize_attr)
    return resized
1939
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output identity-style operator.

    One output tensor is created per input via ``OutputShaper.unaryOp``. Only
    the first output tensor is returned. No ERROR_IF validation is performed.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder passes op["op"] to addOperator, not the op
    # description itself. Accept a bare Op value too, for backward compatibility.
    op_type = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_type, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1947
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    No operator is added; ``op``, ``validator_fcns`` and ``error_name`` are
    accepted for interface uniformity but are unused.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001951
1952 # Type Conversion
Jeremy Johnson708da822023-11-15 16:25:45 +00001953 def build_cast(
1954 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1955 ):
1956 assert len(inputs) == 1
1957 val = inputs[0]
1958 out_dtype = args_dict["out_type"]
1959
1960 result_tensor = OutputShaper.typeConversionOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001961 self.ser, self.rng, val, out_dtype, error_name
1962 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001963
1964 # Invalidate Input/Output list for error if checks.
1965 input_list = [val.name]
Jeremy Johnson708da822023-11-15 16:25:45 +00001966 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001967 pCount, cCount = op["operands"]
1968 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001969 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1970 self, error_name, input_list, output_list
1971 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001972
Les Bell729b0352021-11-24 10:28:21 +00001973 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001974 self.ser,
1975 validator_fcns,
1976 error_name,
1977 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001978 input_shape=val.shape,
Jeremy Johnson708da822023-11-15 16:25:45 +00001979 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001980 input_dtype=val.dtype,
Jeremy Johnson708da822023-11-15 16:25:45 +00001981 output_dtype=result_tensor.dtype,
1982 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001983 input_list=input_list,
1984 output_list=output_list,
1985 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001986 ):
1987 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001988
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001989 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson708da822023-11-15 16:25:45 +00001990
1991 compliance = self.tensorComplianceMetaData(
1992 op, val.dtype, args_dict, result_tensor, error_name
1993 )
1994
1995 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001996
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator with random zero points and scale factors.

    Zero points are drawn according to the input/output dtypes, or forced
    non-zero for the zero-point ERROR_IF tests. A random scale (one per
    channel, or a single one) is converted into multiplier/shift pairs. For
    positive scale32 tests, the input placeholder data is clamped so that
    every zero-point-adjusted value satisfies the apply_scale_32 range for
    its channel's shift. The placeholder .npy file is rewritten if any value
    changed.

    Args:
        op: operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: input tensor.
        out_dtype: output dtype.
        scale32: True for 32-bit multipliers, False for 16-bit.
        double_round: value of the double_round attribute.
        per_channel: if True, one multiplier/shift pair per element of the
            last dimension of ``val``; otherwise a single pair.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Types that carry a zero point get one extra bit of width here
    # (presumably headroom for the zero-point offset - TODO confirm).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^out_type_width) / (2^in_type_width), with a in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift as a Python int so these bounds are computed exactly. With an
        # np.int32 scalar operand, NumPy >= 2 (NEP 50) keeps the result int32,
        # which cannot hold shifts of 32 bits or more.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        npy_path = os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        values = np.load(npy_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        # The bound arrays have length nc and broadcast along the last axis.
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check, element-wise, that every adjusted value fits the original dtype
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(npy_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2169
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. For
    ``CondIfCondNotMatchingBool`` the dtype is replaced with a wrong one.
    For ``CondIfCondShapeNotSizeOne`` the shape is randomly ``[2]`` or
    ``[1, 2]``.

    Returns:
        The constant tensor created by ``self.ser.addConst``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2186
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. For
    the output-list mismatch ERROR_IF tests, the constant in the affected
    block gets a perturbed shape.

    Args:
        op: operator definition dict; ``op["op"]`` is the COND_IF op.
        then_tens: tensor whose shape defines the output shape.
        else_tens: unused.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation rejects
        the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)
    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    # Random draws happen in a fixed order: incorrect shape/data (only for
    # the mismatch tests), then the then-data, then the else-data.
    if error_name in (then_mismatch, else_mismatch):
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches the shape of either block output.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    for block_name, mismatch_error, block_arr in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if valid else None
2255
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to ``a``, ``b``.

    FP32/BF16/FP16/INT32 inputs use ADD (then) and SUB (else). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else).
    For the input/output list mismatch ERROR_IF tests, the affected block
    gets an input or output whose shape differs from ``a``.

    Args:
        op: operator definition dict; ``op["op"]`` is the COND_IF op. The
            same dict is passed to the ERROR_IF validators.
        a: first operand tensor, passed to both blocks.
        b: second operand tensor, passed to both blocks.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation rejects
        the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else ops defined.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Copy of ``a`` (same name/dtype) with every dimension perturbed
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise (same exception type as the former ``assert False``)
        # so the check is not stripped under ``python -O``.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be named ``op``: that would overwrite the
    # operator dict that is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2338
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that repeatedly adds ``a`` to an accumulator.

    Loop-carried values are (counter, ``a``, accumulator). The counter starts
    at ``iter_val`` and the accumulator at zeros. The COND block computes
    ``counter > 0``. The BODY block computes ``acc = a + acc`` and
    ``counter = counter - 1``. For the list-mismatch ERROR_IF tests, the
    affected block inputs/outputs get perturbed shapes. For the cond-graph
    ERROR_IF tests, the COND output gets a non-BOOL dtype or a non-scalar
    shape.

    Args:
        op: operator definition dict; ``op["op"]`` is the WHILE_LOOP op.
        a: tensor added to the accumulator on every iteration.
        iter_val: initial value of the INT32 loop counter.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The loop's accumulator output tensor, or None if ERROR_IF validation
        rejects the test.
    """
    # Loop counter (named iter_tens to avoid shadowing the ``iter`` builtin)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        # A rank-0 counter has no dims to perturb, so give it one instead
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2459
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator from two input tensors.

    Args:
        op: operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val1: first input tensor (presumably the real part - TODO confirm
            against OutputShaper.fft2dOp).
        val2: second input tensor (presumably the imaginary part).
        inverse: value of the FFT attribute's ``inverse`` flag.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The result tensors produced by ``OutputShaper.fft2dOp``, or None if
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    positional_count, const_count = op["operands"]

    names_out = []
    shapes_out = []
    dtypes_out = []
    for res in results:
        names_out.append(res.name)
        shapes_out.append(res.shape)
        dtypes_out.append(res.dtype)

    # ERROR_IF tests may swap in deliberately invalid operand lists.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], names_out
    )

    passes = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=positional_count + const_count,
    )
    if not passes:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound; for now the local_bound attribute is always False
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2510
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator from a single input tensor.

    Args:
        op: operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.

    Returns:
        The result tensors produced by ``OutputShaper.rfft2dOp``, or None if
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    positional_count, const_count = op["operands"]

    names_out = []
    shapes_out = []
    dtypes_out = []
    for res in results:
        names_out.append(res.name)
        shapes_out.append(res.shape)
        dtypes_out.append(res.dtype)

    # ERROR_IF tests may swap in deliberately invalid operand lists.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], names_out
    )

    passes = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=positional_count + const_count,
    )
    if not passes:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound; for now the local_bound attribute is always False
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2556
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary shape operator from two input tensors.

    The output tensor comes from ``OutputShaper.addShapeOp``.

    Args:
        op: operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: list holding exactly two input tensors.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being tested, or None for a positive test.
        qinfo: unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    positional_count, const_count = op["operands"]
    # ERROR_IF tests may swap in deliberately invalid operand lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passes = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=positional_count + const_count,
    )
    if not passes:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, result_tensor, error_name
        ),
    )
2602
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002603 def create_filter_lists(
2604 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2605 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002606 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2607 default_test_rank_range = range(1, 5)
2608 if not shapeFilter:
2609 shapeFilter = [None]
2610
2611 # Calculate the filters based on what is requested and what the operator allows
2612 rmin, rmax = op["rank"]
2613 if rankFilter is not None:
2614 cleanRankFilter = []
2615 # Ensure rankFilter values are allowed by operator
2616 for rank in rankFilter:
2617 if rank >= rmin and rank <= rmax:
2618 cleanRankFilter.append(rank)
2619 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002620 # Ensure default behaviour is bounded by default range or by operator,
2621 # whichever is the smaller range of ranks.
2622 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002623 cleanRankFilter = (
2624 opRankRange
2625 if len(opRankRange) <= len(default_test_rank_range)
2626 else default_test_rank_range
2627 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002628 else:
2629 cleanRankFilter = range(rmin, rmax + 1)
2630
2631 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002632
Matthew Haddon1c00b712021-10-01 15:51:03 +01002633 if dtypeFilter is not None:
2634 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002635 # Create list of operator dtypes filtered by requested dtypes
2636 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002637 if dtype in dtypeFilter or (
2638 isinstance(dtype, list) and dtype[0] in dtypeFilter
2639 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002640 cleanDtypeFilter.append(dtype)
2641 else:
2642 cleanDtypeFilter = dtypes
2643
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002644 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002645 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002646 "shapeFilter": shapeFilter,
2647 "rankFilter": cleanRankFilter,
2648 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002649 }
2650 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002651 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002652 if validator is not None:
2653 validator_info = validator(check=False, op=op)
2654 else:
2655 return None
2656
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002657 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002658
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002659 # Set parameters as required
2660 if error_arguments["rank"] is not None:
2661 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002662 else:
2663 rankFilter = cleanRankFilter
2664
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002665 if error_arguments["dtype"] is not None:
2666 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002667 else:
2668 dtypeFilter = cleanDtypeFilter
2669
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002670 if error_arguments["shape"] is not None:
2671 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002672 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002673 shapeFilter = shapeFilter[
2674 :2
2675 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002676
2677 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002678 "shapeFilter": shapeFilter,
2679 "rankFilter": rankFilter,
2680 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002681 }
2682 return filterDict
2683
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Create the list of tests to generate for one operator.

    Args:
        opName: Name of the operator entry in TOSA_OP_LIST.
        shapeFilter: List of shapes to test. A shape whose length differs
            from the rank being generated is skipped; a None entry imposes no
            shape restriction. Defaults to [None] (passing None explicitly is
            treated the same as the default).
        rankFilter: Optional rank restriction, forwarded to create_filter_lists.
        dtypeFilter: Optional dtype restriction, forwarded to create_filter_lists.
        testType: "positive" for normal tests or "negative" for ERROR_IF tests.

    Returns:
        List of tuples (opName, testStr, dtype, error_name, shapeList, args).
        An empty list is returned as soon as create_filter_lists returns None.

    Raises:
        Exception: If opName is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Avoid a shared mutable default argument
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator (re-seeded for every op)
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # Single pass with no ERROR_IF validator
        error_if_validators = [None]

    # Each test is a tuple of:
    # (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument entry is (argStr, args): the test-name
                    # suffix and the argument list for the build function
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement: drop a test if any invalid-test validator flags it
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2790
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for ``opName`` and serialize it if the build succeeds.

    Steps: create a serializer (``createSerializer(opName, testStr)``),
    generate the input tensors with the op's TensorValuesGen function, then
    call the op's build function. ``testArgs`` is either a dict (the argsDict
    interface, which must contain "dg_type") or a sequence of positional build
    arguments (the older list interface). When the build function returns a
    falsy result nothing is serialized and a message is printed instead.

    Raises:
        Exception: If opName is not in TOSA_OP_LIST.
    """
    try:
        opInfo = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = opInfo["build_fcn"]
    validators = opInfo.get("error_if_validators")

    placeholderCount, constCount = opInfo["operands"]
    operandCount = placeholderCount + constCount

    # CONCAT-style ops take one input per entry in shapeList, so they are
    # exempt from the fixed operand-count checks
    variadic = opInfo["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if variadic else operandCount
        dtypeList = [dtype_or_dtypeList] * repeat

    if not variadic:
        assert (
            len(shapeList) == operandCount
        ), f"shapeList length {len(shapeList)} must match number of operands {operandCount}"
        assert (
            len(dtypeList) == operandCount
        ), f"dtypeList length {len(dtypeList)} must match number of operands {operandCount}"

    qgen = opInfo.get("qgen")
    qinfo = None if qgen is None else qgen(self, opInfo, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # argsDict interface: the values generator returns the tensors plus
        # optional data generation meta data
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
        result = build_fcn(
            self,
            opInfo,
            tens,
            argsDict,
            validator_fcns=validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # List interface: tensors and arguments are passed positionally
        tens = tvgen_fcn(self, opInfo, dtypeList, shapeList, testArgs, error_name)
        try:
            if validators is None:
                # Without validators, qinfo (when present) is a trailing positional
                trailing = () if qinfo is None else (qinfo,)
                result = build_fcn(self, opInfo, *tens, *testArgs, *trailing)
            else:
                keywordArgs = {"validator_fcns": validators, "error_name": error_name}
                if qinfo is not None:
                    keywordArgs["qinfo"] = qinfo
                result = build_fcn(self, opInfo, *tens, *testArgs, **keywordArgs)
        except TypeError as err:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise err

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002916
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST into real ops.

    For each kernel size, a shallow copy of the matching template is stored
    under ``<prefix>_<k0>x<k1>[x<k2>]``, with ``filter`` set to the kernel
    and ``template`` set to False. The 2D kernels produce conv2d,
    depthwise_conv2d and transpose_conv2d entries; the 3D kernels produce
    conv3d entries. Afterwards every entry whose ``template`` field is truthy
    is removed. This does nothing if ``conv2d_TEMPLATE`` is already absent.
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate; level 8K testing uses the large kernel limit
    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels2d = [[1, bigK], [bigK, 2]]
        kernels3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def addFromTemplate(prefix, kernel):
        # Clone <prefix>_TEMPLATE into a concrete op named after the kernel size
        name = prefix + "_" + "x".join(str(dim) for dim in kernel)
        entry = opList[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        opList[name] = entry

    # Interleave the three 2D ops per kernel to keep the original entry order
    for kernel in kernels2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addFromTemplate(prefix, kernel)
    for kernel in kernels3d:
        addFromTemplate("conv3d", kernel)

    # Collect the template keys first: deleting while iterating a dict is an error
    templateNames = [name for name, entry in opList.items() if entry.get("template")]
    for name in templateNames:
        del opList[name]
2971 del self.TOSA_OP_LIST[k]
2972
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also looks for missing required fields (datastructure linting). Each
    TOSA_OP_LIST entry, checked in list order, must have:
      - "operands": a pair of (placeholder count, const count)
      - "build_fcn": a 4-element tuple of functions
      - "types"
      - "op"
    A missing "rank" is set to DEFAULT_RANK_RANGE (entries are updated in place).

    Raises:
        Exception: For the first entry with a missing or malformed required
            field. For "operands"/"build_fcn" the underlying lookup or
            unpacking error is chained as ``__cause__``.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Required fields - unpacking also validates the tuple lengths
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            ) from err

        try:
            _build, _tensorGen, _valuesGen, _argGen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            ) from err

        if "types" not in opInfo:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in opInfo:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Put in default rank range, if missing
        opInfo.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003014
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op enum value for the operator (required)
# 'operands': tuple of (placeholder count, const count) (required)
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen shape function,
#     TensorValuesGen function, ArgGen function or None) (required)
# 'types': list of datatypes to be tested; an entry may itself be a list
#     of datatypes (see TYPE_CONV) (required)
# 'qgen': optional TosaQuantGen function producing the quantization info
# 'invalid_test_validators': optional; positive tests for which any of
#     these returns True are not generated
# 'error_if_validators': optional; negative (ERROR_IF) tests are generated
#     per validator, and they are passed to the build function
# 'data_gen': optional; maps a type category (e.g. "fp") to a tuple of
#     gtu.DataGenType values
# 'template': optional; True marks a template that createDynamicOpLists
#     expands (setting 'filter' to the kernel size) and then removes
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excludes INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

# INT8/INT16 and floating-types (excludes INT32 and INT4)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range applied to ops that do not specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003066
3067 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003068 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003069 "argmax": {
3070 "op": Op.ARGMAX,
3071 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003072 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003073 "build_fcn": (
3074 build_argmax,
3075 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003076 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 TosaArgGen.agAxis,
3078 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003079 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003080 "error_if_validators": (
3081 TosaErrorValidator.evAxisSmallerZero,
3082 TosaErrorValidator.evAxisLargerRank,
3083 TosaErrorValidator.evArgmaxOutputRankMismatch,
3084 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3085 TosaErrorValidator.evWrongRank,
3086 TosaErrorValidator.evWrongInputType,
3087 TosaErrorValidator.evWrongOutputType,
3088 TosaErrorValidator.evWrongInputList,
3089 TosaErrorValidator.evWrongOutputList,
3090 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003091 "data_gen": {
3092 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3093 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "avg_pool2d": {
3096 "op": Op.AVG_POOL2D,
3097 "operands": (1, 0),
3098 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099 "build_fcn": (
3100 build_pool2d,
3101 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003102 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003103 TosaArgGen.agPooling,
3104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "qgen": TosaQuantGen.qgUnary,
3106 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003107 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 "error_if_validators": (
3109 TosaErrorValidator.evKernelSmallerOne,
3110 TosaErrorValidator.evStrideSmallerOne,
3111 TosaErrorValidator.evPadSmallerZero,
3112 TosaErrorValidator.evWrongRank,
3113 TosaErrorValidator.evWrongInputType,
3114 TosaErrorValidator.evWrongOutputType,
3115 TosaErrorValidator.evWrongInputList,
3116 TosaErrorValidator.evWrongOutputList,
3117 TosaErrorValidator.evInputZeroPointNotZero,
3118 TosaErrorValidator.evOutputZeroPointNotZero,
3119 TosaErrorValidator.evPadLargerEqualKernel,
3120 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003121 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003122 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003123 "data_gen": {
3124 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003127 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003128 "conv2d_TEMPLATE": {
3129 "op": Op.CONV2D,
3130 "operands": (1, 2),
3131 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003132 "build_fcn": (
3133 build_conv2d,
3134 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003135 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003136 TosaArgGen.agConv,
3137 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003138 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003139 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003140 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3141 "error_if_validators": (
3142 TosaErrorValidator.evWrongInputType,
3143 TosaErrorValidator.evWrongOutputType,
3144 TosaErrorValidator.evWrongInputList,
3145 TosaErrorValidator.evWrongOutputList,
3146 TosaErrorValidator.evInputZeroPointNotZero,
3147 TosaErrorValidator.evWeightZeroPointNotZero,
3148 TosaErrorValidator.evPadSmallerZero,
3149 TosaErrorValidator.evStrideSmallerOne,
3150 TosaErrorValidator.evDilationSmallerOne,
3151 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003152 TosaErrorValidator.evConvOutputShapeMismatch,
3153 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003154 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003155 "data_gen": {
3156 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3157 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003158 "template": True,
3159 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003160 # Templated operator. Filled in by createDynamicOpLists
3161 "conv3d_TEMPLATE": {
3162 "op": Op.CONV3D,
3163 "operands": (1, 2),
3164 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003165 "build_fcn": (
3166 build_conv3d,
3167 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003168 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 TosaArgGen.agConv,
3170 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003171 "qgen": TosaQuantGen.qgConv,
3172 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003173 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3174 "error_if_validators": (
3175 TosaErrorValidator.evWrongInputType,
3176 TosaErrorValidator.evWrongOutputType,
3177 TosaErrorValidator.evWrongInputList,
3178 TosaErrorValidator.evWrongOutputList,
3179 TosaErrorValidator.evInputZeroPointNotZero,
3180 TosaErrorValidator.evWeightZeroPointNotZero,
3181 TosaErrorValidator.evPadSmallerZero,
3182 TosaErrorValidator.evStrideSmallerOne,
3183 TosaErrorValidator.evDilationSmallerOne,
3184 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003185 TosaErrorValidator.evConvOutputShapeMismatch,
3186 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003187 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003188 "template": True,
3189 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003190 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003191 "depthwise_conv2d_TEMPLATE": {
3192 "op": Op.DEPTHWISE_CONV2D,
3193 "operands": (1, 2),
3194 "filter": [1, 1],
3195 "rank": (4, 4),
3196 "build_fcn": (
3197 build_depthwise_conv2d,
3198 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003199 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003200 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003201 ),
3202 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003203 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003204 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3205 "error_if_validators": (
3206 TosaErrorValidator.evWrongInputType,
3207 TosaErrorValidator.evWrongOutputType,
3208 TosaErrorValidator.evWrongInputList,
3209 TosaErrorValidator.evWrongOutputList,
3210 TosaErrorValidator.evInputZeroPointNotZero,
3211 TosaErrorValidator.evWeightZeroPointNotZero,
3212 TosaErrorValidator.evPadSmallerZero,
3213 TosaErrorValidator.evStrideSmallerOne,
3214 TosaErrorValidator.evDilationSmallerOne,
3215 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003216 TosaErrorValidator.evConvOutputShapeMismatch,
3217 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003218 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003219 "data_gen": {
3220 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3221 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003222 "template": True,
3223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "fully_connected": {
3225 "op": Op.FULLY_CONNECTED,
3226 "operands": (1, 2),
3227 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_fully_connected,
3230 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003231 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003232 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003235 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evInputZeroPointNotZero,
3238 TosaErrorValidator.evWeightZeroPointNotZero,
3239 TosaErrorValidator.evWrongRank,
3240 TosaErrorValidator.evWrongInputType,
3241 TosaErrorValidator.evWrongOutputType,
3242 TosaErrorValidator.evWrongInputList,
3243 TosaErrorValidator.evWrongOutputList,
3244 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003245 "data_gen": {
3246 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003249 "matmul": {
3250 "op": Op.MATMUL,
3251 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003252 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253 "build_fcn": (
3254 build_matmul,
3255 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003256 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003257 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003258 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "qgen": TosaQuantGen.qgMatmul,
3260 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003261 "error_if_validators": (
3262 TosaErrorValidator.evInputZeroPointNotZero,
3263 TosaErrorValidator.evWrongRank,
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003269 "data_gen": {
3270 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003271 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003272 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "max_pool2d": {
3274 "op": Op.MAX_POOL2D,
3275 "operands": (1, 0),
3276 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003277 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003278 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003279 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003280 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 TosaArgGen.agPooling,
3282 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003284 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003285 "error_if_validators": (
3286 TosaErrorValidator.evKernelSmallerOne,
3287 TosaErrorValidator.evStrideSmallerOne,
3288 TosaErrorValidator.evPadSmallerZero,
3289 TosaErrorValidator.evWrongRank,
3290 TosaErrorValidator.evWrongInputType,
3291 TosaErrorValidator.evWrongOutputType,
3292 TosaErrorValidator.evWrongInputList,
3293 TosaErrorValidator.evWrongOutputList,
3294 TosaErrorValidator.evPadLargerEqualKernel,
3295 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003296 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003297 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003298 "data_gen": {
3299 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003302 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003303 "transpose_conv2d_TEMPLATE": {
3304 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003305 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003306 "rank": (4, 4),
3307 "build_fcn": (
3308 build_transpose_conv2d,
3309 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003311 TosaArgGen.agTransposeConv2D,
3312 ),
3313 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003314 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003315 "invalid_test_validators": (
3316 TosaInvalidValidator.ivHeightWidthInvalid,
3317 TosaInvalidValidator.ivNonPositiveOutputShape,
3318 ),
3319 "error_if_validators": (
3320 TosaErrorValidator.evWrongInputType,
3321 TosaErrorValidator.evWrongOutputType,
3322 TosaErrorValidator.evWrongInputList,
3323 TosaErrorValidator.evWrongOutputList,
3324 TosaErrorValidator.evInputZeroPointNotZero,
3325 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003326 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003327 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003328 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003329 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003330 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003331 "template": True,
3332 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003333 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003334 "clamp": {
3335 "op": Op.CLAMP,
3336 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 "build_fcn": (
3338 build_clamp,
3339 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003340 TosaTensorValuesGen.tvgLazyGenDefault,
3341 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003343 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 "error_if_validators": (
3345 TosaErrorValidator.evMaxSmallerMin,
3346 TosaErrorValidator.evWrongInputType,
3347 TosaErrorValidator.evWrongOutputType,
3348 TosaErrorValidator.evWrongInputList,
3349 TosaErrorValidator.evWrongOutputList,
3350 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003351 "data_gen": {
3352 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3353 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003354 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003355 "sigmoid": {
3356 "op": Op.SIGMOID,
3357 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003359 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003360 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003361 TosaTensorValuesGen.tvgLazyGenDefault,
3362 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003363 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003364 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongInputList,
3369 TosaErrorValidator.evWrongOutputList,
3370 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003371 "data_gen": {
3372 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3373 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003374 },
3375 "tanh": {
3376 "op": Op.TANH,
3377 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003378 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003379 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003381 TosaTensorValuesGen.tvgLazyGenDefault,
3382 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003383 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003384 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003385 "error_if_validators": (
3386 TosaErrorValidator.evWrongInputType,
3387 TosaErrorValidator.evWrongOutputType,
3388 TosaErrorValidator.evWrongInputList,
3389 TosaErrorValidator.evWrongOutputList,
3390 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003391 "data_gen": {
3392 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3393 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003394 "compliance": {
3395 "abs_error_lower_bound": 0.5,
3396 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003397 },
Won Jeon78155c62023-06-10 00:20:04 +00003398 "erf": {
3399 "op": Op.ERF,
3400 "operands": (1, 0),
3401 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003402 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003403 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003404 TosaTensorValuesGen.tvgLazyGenDefault,
3405 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003406 ),
3407 "types": TYPE_FP,
3408 "error_if_validators": (
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003414 "data_gen": {
3415 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3416 },
3417 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 # Elementwise Binary Operators
3420 "add": {
3421 "op": Op.ADD,
3422 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003423 "build_fcn": (
3424 build_binary_broadcast,
3425 TosaTensorGen.tgBroadcastFuzz,
3426 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003427 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003429 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003430 "error_if_validators": (
3431 TosaErrorValidator.evRankMismatch,
3432 TosaErrorValidator.evWrongInputType,
3433 TosaErrorValidator.evWrongOutputType,
3434 TosaErrorValidator.evWrongInputList,
3435 TosaErrorValidator.evWrongOutputList,
3436 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003437 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003438 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003439 "data_gen": {
3440 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3441 },
3442 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "arithmetic_right_shift": {
3445 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3446 "operands": (2, 0),
3447 "build_fcn": (
3448 build_arithmetic_right_shift,
3449 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003450 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003451 TosaArgGen.agArithmeticRightShift,
3452 ),
3453 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003454 "error_if_validators": (
3455 TosaErrorValidator.evRankMismatch,
3456 TosaErrorValidator.evWrongInputType,
3457 TosaErrorValidator.evWrongOutputType,
3458 TosaErrorValidator.evWrongInputList,
3459 TosaErrorValidator.evWrongOutputList,
3460 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003461 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "bitwise_and": {
3465 "op": Op.BITWISE_AND,
3466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_binary_broadcast,
3469 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003470 TosaTensorValuesGen.tvgLazyGenDefault,
3471 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evRankMismatch,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003481 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "bitwise_or": {
3485 "op": Op.BITWISE_OR,
3486 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 "build_fcn": (
3488 build_binary_broadcast,
3489 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003490 TosaTensorValuesGen.tvgLazyGenDefault,
3491 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003494 "error_if_validators": (
3495 TosaErrorValidator.evRankMismatch,
3496 TosaErrorValidator.evWrongInputType,
3497 TosaErrorValidator.evWrongOutputType,
3498 TosaErrorValidator.evWrongInputList,
3499 TosaErrorValidator.evWrongOutputList,
3500 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003501 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "bitwise_xor": {
3505 "op": Op.BITWISE_XOR,
3506 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 "build_fcn": (
3508 build_binary_broadcast,
3509 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003510 TosaTensorValuesGen.tvgLazyGenDefault,
3511 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003514 "error_if_validators": (
3515 TosaErrorValidator.evRankMismatch,
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongInputList,
3519 TosaErrorValidator.evWrongOutputList,
3520 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003521 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003524 "intdiv": {
3525 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003526 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 "build_fcn": (
3528 build_binary_broadcast,
3529 TosaTensorGen.tgBroadcastFuzz,
3530 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003531 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003533 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 "error_if_validators": (
3535 TosaErrorValidator.evRankMismatch,
3536 TosaErrorValidator.evWrongInputType,
3537 TosaErrorValidator.evWrongOutputType,
3538 TosaErrorValidator.evWrongInputList,
3539 TosaErrorValidator.evWrongOutputList,
3540 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003541 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003542 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003543 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 "logical_and": {
3545 "op": Op.LOGICAL_AND,
3546 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 "build_fcn": (
3548 build_binary_broadcast,
3549 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003550 TosaTensorValuesGen.tvgLazyGenDefault,
3551 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 "error_if_validators": (
3555 TosaErrorValidator.evRankMismatch,
3556 TosaErrorValidator.evWrongInputType,
3557 TosaErrorValidator.evWrongOutputType,
3558 TosaErrorValidator.evWrongInputList,
3559 TosaErrorValidator.evWrongOutputList,
3560 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003561 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003562 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "logical_left_shift": {
3565 "op": Op.LOGICAL_LEFT_SHIFT,
3566 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 "build_fcn": (
3568 build_binary_broadcast,
3569 TosaTensorGen.tgBroadcastFuzz,
3570 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003571 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003572 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 "error_if_validators": (
3575 TosaErrorValidator.evRankMismatch,
3576 TosaErrorValidator.evWrongInputType,
3577 TosaErrorValidator.evWrongOutputType,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003581 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003584 "logical_right_shift": {
3585 "op": Op.LOGICAL_RIGHT_SHIFT,
3586 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003587 "build_fcn": (
3588 build_binary_broadcast,
3589 TosaTensorGen.tgBroadcastFuzz,
3590 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003591 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003592 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 "error_if_validators": (
3595 TosaErrorValidator.evRankMismatch,
3596 TosaErrorValidator.evWrongInputType,
3597 TosaErrorValidator.evWrongOutputType,
3598 TosaErrorValidator.evWrongInputList,
3599 TosaErrorValidator.evWrongOutputList,
3600 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003601 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003602 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "logical_or": {
3605 "op": Op.LOGICAL_OR,
3606 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607 "build_fcn": (
3608 build_binary_broadcast,
3609 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003610 TosaTensorValuesGen.tvgLazyGenDefault,
3611 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003613 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 "error_if_validators": (
3615 TosaErrorValidator.evRankMismatch,
3616 TosaErrorValidator.evWrongInputType,
3617 TosaErrorValidator.evWrongOutputType,
3618 TosaErrorValidator.evWrongInputList,
3619 TosaErrorValidator.evWrongOutputList,
3620 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003621 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003622 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003623 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003624 "logical_xor": {
3625 "op": Op.LOGICAL_XOR,
3626 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003627 "build_fcn": (
3628 build_binary_broadcast,
3629 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003630 TosaTensorValuesGen.tvgLazyGenDefault,
3631 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003633 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003634 "error_if_validators": (
3635 TosaErrorValidator.evRankMismatch,
3636 TosaErrorValidator.evWrongInputType,
3637 TosaErrorValidator.evWrongOutputType,
3638 TosaErrorValidator.evWrongInputList,
3639 TosaErrorValidator.evWrongOutputList,
3640 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003641 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003642 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003644 "maximum": {
3645 "op": Op.MAXIMUM,
3646 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003647 "build_fcn": (
3648 build_binary_broadcast,
3649 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003650 TosaTensorValuesGen.tvgLazyGenDefault,
3651 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 "error_if_validators": (
3655 TosaErrorValidator.evRankMismatch,
3656 TosaErrorValidator.evWrongInputType,
3657 TosaErrorValidator.evWrongOutputType,
3658 TosaErrorValidator.evWrongInputList,
3659 TosaErrorValidator.evWrongOutputList,
3660 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003661 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003663 "data_gen": {
3664 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3665 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003667 "minimum": {
3668 "op": Op.MINIMUM,
3669 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003670 "build_fcn": (
3671 build_binary_broadcast,
3672 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003673 TosaTensorValuesGen.tvgLazyGenDefault,
3674 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003675 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003677 "error_if_validators": (
3678 TosaErrorValidator.evRankMismatch,
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003684 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003685 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003686 "data_gen": {
3687 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 "mul": {
3691 "op": Op.MUL,
3692 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003693 "build_fcn": (
3694 build_mul,
3695 TosaTensorGen.tgBroadcastFuzz,
3696 TosaTensorValuesGen.tvgMul,
3697 TosaArgGen.agMul,
3698 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003699 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003700 "error_if_validators": (
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 TosaErrorValidator.evRankMismatch,
3706 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003707 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003708 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003709 "data_gen": {
3710 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3711 },
3712 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003714 "pow": {
3715 "op": Op.POW,
3716 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 "build_fcn": (
3718 build_binary_broadcast,
3719 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003720 TosaTensorValuesGen.tvgPow,
3721 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003722 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003724 "error_if_validators": (
3725 TosaErrorValidator.evRankMismatch,
3726 TosaErrorValidator.evWrongInputType,
3727 TosaErrorValidator.evWrongOutputType,
3728 TosaErrorValidator.evWrongInputList,
3729 TosaErrorValidator.evWrongOutputList,
3730 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003731 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003733 "data_gen": {
3734 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003736 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 "sub": {
3738 "op": Op.SUB,
3739 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003740 "build_fcn": (
3741 build_binary_broadcast,
3742 TosaTensorGen.tgBroadcastFuzz,
3743 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003744 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003745 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003747 "error_if_validators": (
3748 TosaErrorValidator.evRankMismatch,
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongInputList,
3752 TosaErrorValidator.evWrongOutputList,
3753 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003754 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003755 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003756 "data_gen": {
3757 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3758 },
3759 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "table": {
3762 "op": Op.TABLE,
3763 # Use the automatic generation functions to create the input array
3764 # but create the table tensor in the build function, as it may be
3765 # a different type from the input
3766 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 "build_fcn": (
3768 build_table,
3769 TosaTensorGen.tgBasic,
3770 TosaTensorValuesGen.tvgDefault,
3771 TosaArgGen.agTable,
3772 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003773 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003774 "error_if_validators": (
3775 TosaErrorValidator.evWrongInputType,
3776 TosaErrorValidator.evWrongOutputType,
3777 TosaErrorValidator.evWrongInputList,
3778 TosaErrorValidator.evWrongOutputList,
3779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 # Elementwise Unary operators
3782 "abs": {
3783 "op": Op.ABS,
3784 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_unary,
3787 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003788 TosaTensorValuesGen.tvgLazyGenDefault,
3789 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003798 "data_gen": {
3799 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3800 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 "bitwise_not": {
3803 "op": Op.BITWISE_NOT,
3804 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 "build_fcn": (
3806 build_unary,
3807 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003808 TosaTensorValuesGen.tvgLazyGenDefault,
3809 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003810 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003811 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003812 "error_if_validators": (
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "ceil": {
3820 "op": Op.CEIL,
3821 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 "build_fcn": (
3823 build_unary,
3824 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003825 TosaTensorValuesGen.tvgLazyGenDefault,
3826 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003829 "error_if_validators": (
3830 TosaErrorValidator.evWrongInputType,
3831 TosaErrorValidator.evWrongOutputType,
3832 TosaErrorValidator.evWrongInputList,
3833 TosaErrorValidator.evWrongOutputList,
3834 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003835 "data_gen": {
3836 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3837 },
3838 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 "clz": {
3841 "op": Op.CLZ,
3842 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003843 "build_fcn": (
3844 build_unary,
3845 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003846 TosaTensorValuesGen.tvgLazyGenDefault,
3847 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003848 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003849 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003850 "error_if_validators": (
3851 TosaErrorValidator.evWrongInputType,
3852 TosaErrorValidator.evWrongOutputType,
3853 TosaErrorValidator.evWrongInputList,
3854 TosaErrorValidator.evWrongOutputList,
3855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003856 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 "exp": {
3858 "op": Op.EXP,
3859 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003860 "build_fcn": (
3861 build_unary,
3862 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003863 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003864 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003865 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003867 "error_if_validators": (
3868 TosaErrorValidator.evWrongInputType,
3869 TosaErrorValidator.evWrongOutputType,
3870 TosaErrorValidator.evWrongInputList,
3871 TosaErrorValidator.evWrongOutputList,
3872 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003873 "data_gen": {
3874 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3875 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 "floor": {
3878 "op": Op.FLOOR,
3879 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 "build_fcn": (
3881 build_unary,
3882 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003883 TosaTensorValuesGen.tvgLazyGenDefault,
3884 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003885 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003886 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003887 "error_if_validators": (
3888 TosaErrorValidator.evWrongInputType,
3889 TosaErrorValidator.evWrongOutputType,
3890 TosaErrorValidator.evWrongInputList,
3891 TosaErrorValidator.evWrongOutputList,
3892 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003893 "data_gen": {
3894 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3895 },
3896 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "log": {
3899 "op": Op.LOG,
3900 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003901 "build_fcn": (
3902 build_unary,
3903 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003904 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003905 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003908 "error_if_validators": (
3909 TosaErrorValidator.evWrongInputType,
3910 TosaErrorValidator.evWrongOutputType,
3911 TosaErrorValidator.evWrongInputList,
3912 TosaErrorValidator.evWrongOutputList,
3913 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003914 "data_gen": {
3915 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3916 },
3917 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003918 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003919 "logical_not": {
3920 "op": Op.LOGICAL_NOT,
3921 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 "build_fcn": (
3923 build_unary,
3924 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003925 TosaTensorValuesGen.tvgLazyGenDefault,
3926 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003927 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003928 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003929 "error_if_validators": (
3930 TosaErrorValidator.evWrongInputType,
3931 TosaErrorValidator.evWrongOutputType,
3932 TosaErrorValidator.evWrongInputList,
3933 TosaErrorValidator.evWrongOutputList,
3934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 "negate": {
3937 "op": Op.NEGATE,
3938 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003939 "build_fcn": (
3940 build_unary,
3941 TosaTensorGen.tgBasic,
3942 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003943 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 "qgen": TosaQuantGen.qgUnary,
3946 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003947 "error_if_validators": (
3948 TosaErrorValidator.evInputZeroPointNotZero,
3949 TosaErrorValidator.evOutputZeroPointNotZero,
3950 TosaErrorValidator.evWrongInputType,
3951 TosaErrorValidator.evWrongOutputType,
3952 TosaErrorValidator.evWrongInputList,
3953 TosaErrorValidator.evWrongOutputList,
3954 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003955 "data_gen": {
3956 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003959 "reciprocal": {
3960 "op": Op.RECIPROCAL,
3961 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 "build_fcn": (
3963 build_unary,
3964 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003965 TosaTensorValuesGen.tvgLazyGenDefault,
3966 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003969 "error_if_validators": (
3970 TosaErrorValidator.evWrongInputType,
3971 TosaErrorValidator.evWrongOutputType,
3972 TosaErrorValidator.evWrongInputList,
3973 TosaErrorValidator.evWrongOutputList,
3974 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003975 "data_gen": {
3976 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3977 },
3978 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003980 "rsqrt": {
3981 "op": Op.RSQRT,
3982 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003983 "build_fcn": (
3984 build_unary,
3985 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003986 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003987 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003988 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003990 "error_if_validators": (
3991 TosaErrorValidator.evWrongInputType,
3992 TosaErrorValidator.evWrongOutputType,
3993 TosaErrorValidator.evWrongInputList,
3994 TosaErrorValidator.evWrongOutputList,
3995 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003996 "data_gen": {
3997 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3998 },
3999 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004001 # Elementwise Ternary operators
4002 "select": {
4003 "op": Op.SELECT,
4004 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004005 "build_fcn": (
4006 build_select,
4007 TosaTensorGen.tgBroadcastFuzz,
4008 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004009 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004010 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004012 "error_if_validators": (
4013 TosaErrorValidator.evRankMismatch,
4014 TosaErrorValidator.evWrongInputType,
4015 TosaErrorValidator.evWrongOutputType,
4016 TosaErrorValidator.evWrongInputList,
4017 TosaErrorValidator.evWrongOutputList,
4018 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004019 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004020 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004021 "data_gen": {
4022 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4023 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 # Comparison operators
4026 "equal": {
4027 "op": Op.EQUAL,
4028 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004029 "build_fcn": (
4030 build_comparison,
4031 TosaTensorGen.tgBroadcastFuzz,
4032 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004033 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004036 "error_if_validators": (
4037 TosaErrorValidator.evRankMismatch,
4038 TosaErrorValidator.evWrongInputType,
4039 TosaErrorValidator.evWrongOutputType,
4040 TosaErrorValidator.evWrongInputList,
4041 TosaErrorValidator.evWrongOutputList,
4042 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004043 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004044 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004045 "data_gen": {
4046 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004049 "greater_equal": {
4050 "op": Op.GREATER_EQUAL,
4051 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004052 "build_fcn": (
4053 build_comparison,
4054 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004055 TosaTensorValuesGen.tvgLazyGenDefault,
4056 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004059 "error_if_validators": (
4060 TosaErrorValidator.evRankMismatch,
4061 TosaErrorValidator.evWrongInputType,
4062 TosaErrorValidator.evWrongOutputType,
4063 TosaErrorValidator.evWrongInputList,
4064 TosaErrorValidator.evWrongOutputList,
4065 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004066 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004067 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004068 "data_gen": {
4069 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4070 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004071 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "greater": {
4073 "op": Op.GREATER,
4074 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004075 "build_fcn": (
4076 build_comparison,
4077 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004078 TosaTensorValuesGen.tvgLazyGenDefault,
4079 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004080 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004082 "error_if_validators": (
4083 TosaErrorValidator.evRankMismatch,
4084 TosaErrorValidator.evWrongInputType,
4085 TosaErrorValidator.evWrongOutputType,
4086 TosaErrorValidator.evWrongInputList,
4087 TosaErrorValidator.evWrongOutputList,
4088 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004089 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004090 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004091 "data_gen": {
4092 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 # Reduction operators
4096 "reduce_all": {
4097 "op": Op.REDUCE_ALL,
4098 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004099 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004100 "build_fcn": (
4101 build_reduce,
4102 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004103 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 TosaArgGen.agAxis,
4105 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004106 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004107 "error_if_validators": (
4108 TosaErrorValidator.evAxisLargerRank,
4109 TosaErrorValidator.evAxisSmallerZero,
4110 TosaErrorValidator.evShapeOfAxisNotOne,
4111 TosaErrorValidator.evWrongInputType,
4112 TosaErrorValidator.evWrongOutputType,
4113 TosaErrorValidator.evWrongRank,
4114 TosaErrorValidator.evWrongInputList,
4115 TosaErrorValidator.evWrongOutputList,
4116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004118 "reduce_any": {
4119 "op": Op.REDUCE_ANY,
4120 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004121 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004122 "build_fcn": (
4123 build_reduce,
4124 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004125 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004126 TosaArgGen.agAxis,
4127 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004128 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004129 "error_if_validators": (
4130 TosaErrorValidator.evAxisLargerRank,
4131 TosaErrorValidator.evAxisSmallerZero,
4132 TosaErrorValidator.evShapeOfAxisNotOne,
4133 TosaErrorValidator.evWrongInputType,
4134 TosaErrorValidator.evWrongOutputType,
4135 TosaErrorValidator.evWrongRank,
4136 TosaErrorValidator.evWrongInputList,
4137 TosaErrorValidator.evWrongOutputList,
4138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004139 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004140 "reduce_max": {
4141 "op": Op.REDUCE_MAX,
4142 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004143 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004144 "build_fcn": (
4145 build_reduce,
4146 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004147 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004148 TosaArgGen.agAxis,
4149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 "error_if_validators": (
4152 TosaErrorValidator.evAxisLargerRank,
4153 TosaErrorValidator.evAxisSmallerZero,
4154 TosaErrorValidator.evShapeOfAxisNotOne,
4155 TosaErrorValidator.evWrongInputType,
4156 TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongRank,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
4160 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004161 "data_gen": {
4162 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004166 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004167 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004168 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004169 "build_fcn": (
4170 build_reduce,
4171 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004172 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 TosaArgGen.agAxis,
4174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004176 "error_if_validators": (
4177 TosaErrorValidator.evAxisLargerRank,
4178 TosaErrorValidator.evAxisSmallerZero,
4179 TosaErrorValidator.evShapeOfAxisNotOne,
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongRank,
4183 TosaErrorValidator.evWrongInputList,
4184 TosaErrorValidator.evWrongOutputList,
4185 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004186 "data_gen": {
4187 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 "reduce_product": {
4191 "op": Op.REDUCE_PRODUCT,
4192 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004193 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_reduce,
4196 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004197 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004198 TosaArgGen.agAxis,
4199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evAxisLargerRank,
4203 TosaErrorValidator.evAxisSmallerZero,
4204 TosaErrorValidator.evShapeOfAxisNotOne,
4205 TosaErrorValidator.evWrongInputType,
4206 TosaErrorValidator.evWrongOutputType,
4207 TosaErrorValidator.evWrongRank,
4208 TosaErrorValidator.evWrongInputList,
4209 TosaErrorValidator.evWrongOutputList,
4210 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004211 "data_gen": {
4212 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4213 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004214 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004215 "reduce_sum": {
4216 "op": Op.REDUCE_SUM,
4217 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004218 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004219 "build_fcn": (
4220 build_reduce,
4221 TosaTensorGen.tgBasic,
4222 TosaTensorValuesGen.tvgReduceSum,
4223 TosaArgGen.agAxis,
4224 ),
James Ward24dbc422022-10-19 12:20:31 +01004225 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004226 "error_if_validators": (
4227 TosaErrorValidator.evAxisLargerRank,
4228 TosaErrorValidator.evAxisSmallerZero,
4229 TosaErrorValidator.evShapeOfAxisNotOne,
4230 TosaErrorValidator.evWrongInputType,
4231 TosaErrorValidator.evWrongOutputType,
4232 TosaErrorValidator.evWrongRank,
4233 TosaErrorValidator.evWrongInputList,
4234 TosaErrorValidator.evWrongOutputList,
4235 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004236 "data_gen": {
4237 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4238 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004239 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004240 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004241 "concat": {
4242 "op": Op.CONCAT,
4243 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 "build_fcn": (
4245 build_concat,
4246 TosaTensorGen.tgConcat,
4247 TosaTensorValuesGen.tvgConcat,
4248 TosaArgGen.agAxis,
4249 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004250 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004251 "error_if_validators": (
4252 TosaErrorValidator.evAxisLargerRank,
4253 TosaErrorValidator.evAxisSmallerZero,
4254 TosaErrorValidator.evConcatInputRankMismatch,
4255 TosaErrorValidator.evConcatShapeSumMismatch,
4256 TosaErrorValidator.evConcatInputDimMismatch,
4257 TosaErrorValidator.evWrongInputType,
4258 TosaErrorValidator.evWrongOutputType,
4259 TosaErrorValidator.evWrongOutputList,
4260 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004261 "data_gen": {
4262 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4263 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004264 },
4265 "pad": {
4266 "op": Op.PAD,
4267 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004268 "build_fcn": (
4269 build_pad,
4270 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004271 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004272 TosaArgGen.agPad,
4273 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004274 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004275 "error_if_validators": (
4276 TosaErrorValidator.evWrongInputType,
4277 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004278 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 TosaErrorValidator.evWrongOutputType,
4280 TosaErrorValidator.evWrongInputList,
4281 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004282 TosaErrorValidator.evRankMismatch,
4283 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004284 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004285 "data_gen": {
4286 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4287 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004288 },
Won Jeona21b2e82023-08-10 10:33:01 +00004289 "dim": {
4290 "op": Op.DIM,
4291 "operands": (1, 0),
4292 "build_fcn": (
4293 build_dim,
4294 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004295 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004296 TosaArgGen.agAxis,
4297 ),
4298 "types": TYPE_FIB,
4299 "error_if_validators": (
4300 TosaErrorValidator.evAxisLargerRank,
4301 TosaErrorValidator.evAxisSmallerZero,
4302 TosaErrorValidator.evWrongInputType,
4303 TosaErrorValidator.evWrongInputList,
4304 TosaErrorValidator.evWrongOutputList,
4305 TosaErrorValidator.evWrongRank,
4306 ),
4307 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004308 "reshape": {
4309 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004310 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 "build_fcn": (
4312 build_reshape,
4313 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004314 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004315 TosaArgGen.agReshape,
4316 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004317 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004318 "error_if_validators": (
4319 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4320 TosaErrorValidator.evWrongInputType,
4321 TosaErrorValidator.evWrongOutputType,
4322 TosaErrorValidator.evWrongInputList,
4323 TosaErrorValidator.evWrongOutputList,
4324 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004325 "data_gen": {
4326 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4327 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004328 },
4329 "reverse": {
4330 "op": Op.REVERSE,
4331 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004332 "build_fcn": (
4333 build_reverse,
4334 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004335 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004336 TosaArgGen.agAxis,
4337 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004338 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004339 "error_if_validators": (
4340 TosaErrorValidator.evAxisSmallerZero,
4341 TosaErrorValidator.evAxisLargerRank,
4342 TosaErrorValidator.evWrongInputType,
4343 TosaErrorValidator.evWrongOutputType,
4344 TosaErrorValidator.evWrongInputList,
4345 TosaErrorValidator.evWrongOutputList,
4346 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004347 },
4348 "slice": {
4349 "op": Op.SLICE,
4350 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004351 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004352 "build_fcn": (
4353 build_slice,
4354 TosaTensorGen.tgBasic,
evacha017f7d4252024-01-24 12:08:09 +00004355 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004356 TosaArgGen.agSlice,
4357 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004359 "error_if_validators": (
4360 TosaErrorValidator.evStartSmallerZero,
4361 TosaErrorValidator.evSizeSmallerEqualZero,
4362 TosaErrorValidator.evStartSizeOutsideBounds,
4363 TosaErrorValidator.evSizeOutputShapeMismatch,
4364 TosaErrorValidator.evInputSizeStartLengthMismatch,
4365 TosaErrorValidator.evWrongRank,
4366 TosaErrorValidator.evWrongInputType,
4367 TosaErrorValidator.evWrongOutputType,
4368 TosaErrorValidator.evWrongInputList,
4369 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004370 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004371 ),
evacha017f7d4252024-01-24 12:08:09 +00004372 "data_gen": {
4373 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4374 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 },
4376 "tile": {
4377 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004378 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004379 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004380 "build_fcn": (
4381 build_tile,
4382 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004383 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004384 TosaArgGen.agTile,
4385 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004386 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004387 "error_if_validators": (
4388 TosaErrorValidator.evWrongInputType,
4389 TosaErrorValidator.evWrongOutputType,
4390 TosaErrorValidator.evWrongInputList,
4391 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004392 TosaErrorValidator.evRankMismatch,
4393 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004394 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004395 "data_gen": {
4396 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4397 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004398 },
4399 "transpose": {
4400 "op": Op.TRANSPOSE,
4401 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004402 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004403 "build_fcn": (
4404 build_transpose,
4405 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004406 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004407 TosaArgGen.agTranspose,
4408 ),
4409 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004410 "error_if_validators": (
4411 TosaErrorValidator.evIndexOutsideBounds,
4412 TosaErrorValidator.evIndexUsedTwice,
4413 TosaErrorValidator.evWrongInputType,
4414 TosaErrorValidator.evWrongOutputType,
4415 TosaErrorValidator.evWrongInputList,
4416 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004417 TosaErrorValidator.evWrongRank,
4418 TosaErrorValidator.evRankMismatch,
4419 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004420 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004421 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 # Data nodes
4423 "const": {
4424 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004425 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004426 "build_fcn": (
4427 build_const,
4428 TosaTensorGen.tgBasic,
4429 TosaTensorValuesGen.tvgDefault,
4430 None,
4431 ),
Luke Hutton65872422023-02-20 10:33:04 +00004432 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004434 "identity": {
4435 "op": Op.IDENTITY,
4436 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 "build_fcn": (
4438 build_unary,
4439 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004440 TosaTensorValuesGen.tvgLazyGenDefault,
4441 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004444 "data_gen": {
4445 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4446 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004447 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004448 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004449 "gather": {
4450 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004451 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004452 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004453 "build_fcn": (
4454 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004455 TosaTensorGen.tgGather,
4456 TosaTensorValuesGen.tvgGather,
4457 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004458 ),
James Ward24dbc422022-10-19 12:20:31 +01004459 "types": (
4460 DType.INT8,
4461 DType.INT16,
4462 DType.INT32,
4463 DType.FP16,
4464 DType.BF16,
4465 DType.FP32,
4466 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004467 "error_if_validators": (
4468 TosaErrorValidator.evWrongInputType,
4469 TosaErrorValidator.evWrongOutputType,
4470 TosaErrorValidator.evWrongInputList,
4471 TosaErrorValidator.evWrongOutputList,
4472 TosaErrorValidator.evWrongRank,
4473 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004474 "data_gen": {
4475 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4476 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004477 },
4478 "scatter": {
4479 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004480 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004481 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004482 "build_fcn": (
4483 build_scatter,
4484 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004485 TosaTensorValuesGen.tvgScatter,
4486 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004487 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004488 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004489 "error_if_validators": (
4490 TosaErrorValidator.evWrongInputType,
4491 TosaErrorValidator.evWrongOutputType,
4492 TosaErrorValidator.evWrongInputList,
4493 TosaErrorValidator.evWrongOutputList,
4494 TosaErrorValidator.evWrongRank,
4495 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004496 "data_gen": {
4497 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4498 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004499 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004500 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004501 "resize": {
4502 "op": Op.RESIZE,
4503 "operands": (1, 0),
4504 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004505 "build_fcn": (
4506 build_resize,
4507 TosaTensorGen.tgNHWC,
4508 TosaTensorValuesGen.tvgDefault,
4509 TosaArgGen.agResize,
4510 ),
James Ward24dbc422022-10-19 12:20:31 +01004511 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004512 "invalid_test_validators": (
4513 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004514 ),
4515 "error_if_validators": (
4516 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004517 TosaErrorValidator.evScaleSmallerEqualZero,
4518 TosaErrorValidator.evScaleNLargerMax,
4519 TosaErrorValidator.evScaleDLargerMax,
4520 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004521 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004522 TosaErrorValidator.evBorderSmallerMin,
4523 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004524 TosaErrorValidator.evWrongInputType,
4525 TosaErrorValidator.evWrongOutputType,
4526 TosaErrorValidator.evWrongRank,
4527 TosaErrorValidator.evWrongInputList,
4528 TosaErrorValidator.evWrongOutputList,
4529 TosaErrorValidator.evBatchMismatch,
4530 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004531 TosaErrorValidator.evResizeOutputShapeMismatch,
4532 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004533 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004534 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004535 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004536 "cast": {
4537 "op": Op.CAST,
4538 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004539 "build_fcn": (
4540 build_cast,
4541 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004542 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004543 TosaArgGen.agCast,
4544 ),
James Ward8b390432022-08-12 20:48:56 +01004545 "types": (
4546 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004547 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004548 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004549 DType.INT8,
4550 DType.INT16,
4551 DType.INT32,
4552 DType.BOOL,
4553 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004554 "error_if_validators": (
4555 TosaErrorValidator.evWrongInputType,
4556 TosaErrorValidator.evWrongOutputType,
4557 TosaErrorValidator.evWrongInputList,
4558 TosaErrorValidator.evWrongOutputList,
4559 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004560 "data_gen": {
4561 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4562 },
4563 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004564 },
4565 "rescale": {
4566 "op": Op.RESCALE,
4567 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004568 "build_fcn": (
4569 build_rescale,
4570 TosaTensorGen.tgBasic,
4571 TosaTensorValuesGen.tvgDefault,
4572 TosaArgGen.agRescale,
4573 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004574 "types": [
4575 DType.UINT8,
4576 DType.INT8,
4577 DType.INT16,
4578 DType.INT32,
4579 DType.INT48,
4580 DType.UINT16,
4581 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004582 "error_if_validators": (
4583 TosaErrorValidator.evInputZeroPointNotZero,
4584 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004585 TosaErrorValidator.evU16InputZeroPointNotValid,
4586 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004587 TosaErrorValidator.evScaleTrue,
4588 TosaErrorValidator.evScaleNotTrue,
4589 TosaErrorValidator.evWrongInputType,
4590 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004591 TosaErrorValidator.evWrongInputList,
4592 TosaErrorValidator.evWrongOutputList,
4593 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004594 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004595 # Custom
4596 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004597 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004598 # Two variants of cond_if, one that generates one of two constant tensors (no
4599 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4600 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004601 "cond_if_const": {
4602 "op": Op.COND_IF,
4603 "operands": (0, 2),
4604 "build_fcn": (
4605 build_cond_if_const,
4606 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004607 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004608 TosaArgGen.agCondIf,
4609 ),
4610 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004611 "error_if_validators": (
4612 TosaErrorValidator.evOutputListThenGraphMismatch,
4613 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004614 TosaErrorValidator.evCondIfCondNotMatchingBool,
4615 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004616 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004617 },
4618 "cond_if_binary": {
4619 "op": Op.COND_IF,
4620 "operands": (2, 0),
4621 "build_fcn": (
4622 build_cond_if_binary,
4623 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004624 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004625 TosaArgGen.agCondIf,
4626 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004627 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004628 "error_if_validators": (
4629 TosaErrorValidator.evInputListThenGraphMismatch,
4630 TosaErrorValidator.evInputListElseGraphMismatch,
4631 TosaErrorValidator.evOutputListThenGraphMismatch,
4632 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004633 TosaErrorValidator.evCondIfCondNotMatchingBool,
4634 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004636 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004637 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004638 "while_loop": {
4639 "op": Op.WHILE_LOOP,
4640 "operands": (0, 1),
4641 "build_fcn": (
4642 build_while_loop,
4643 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004644 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004645 TosaArgGen.agWhileLoop,
4646 ),
4647 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004648 "error_if_validators": (
4649 TosaErrorValidator.evInputListOutputListMismatch,
4650 TosaErrorValidator.evInputListCondGraphMismatch,
4651 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4652 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4653 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004654 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004655 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004656 },
Luke Hutton57287132023-02-06 14:54:18 +00004657 "fft2d": {
4658 "op": Op.FFT2D,
4659 "operands": (2, 0),
4660 "rank": (3, 3),
4661 "build_fcn": (
4662 build_fft2d,
4663 TosaTensorGen.tgFFT2d,
4664 TosaTensorValuesGen.tvgDefault,
4665 TosaArgGen.agFFT2d,
4666 ),
4667 "types": [DType.FP32],
4668 "error_if_validators": (
4669 TosaErrorValidator.evWrongInputType,
4670 TosaErrorValidator.evWrongOutputType,
4671 TosaErrorValidator.evWrongInputList,
4672 TosaErrorValidator.evWrongOutputList,
4673 TosaErrorValidator.evWrongRank,
4674 TosaErrorValidator.evBatchMismatch,
4675 TosaErrorValidator.evKernelNotPowerOfTwo,
4676 TosaErrorValidator.evFFTInputShapeMismatch,
4677 TosaErrorValidator.evFFTOutputShapeMismatch,
4678 ),
4679 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004680 "rfft2d": {
4681 "op": Op.RFFT2D,
4682 "operands": (1, 0),
4683 "rank": (3, 3),
4684 "build_fcn": (
4685 build_rfft2d,
4686 TosaTensorGen.tgRFFT2d,
4687 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004688 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004689 ),
4690 "types": [DType.FP32],
4691 "error_if_validators": (
4692 TosaErrorValidator.evWrongInputType,
4693 TosaErrorValidator.evWrongOutputType,
4694 TosaErrorValidator.evWrongInputList,
4695 TosaErrorValidator.evWrongOutputList,
4696 TosaErrorValidator.evWrongRank,
4697 TosaErrorValidator.evBatchMismatch,
4698 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004699 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004700 ),
4701 },
Won Jeon74342e52024-01-09 00:34:40 +00004702 # Shape
4703 "add_shape": {
4704 "op": Op.ADD_SHAPE,
4705 "operands": (2, 0),
4706 "build_fcn": (
4707 build_shape_op,
4708 TosaTensorGen.tgShape,
4709 TosaTensorValuesGen.tvgAddSub,
4710 TosaArgGen.agNone,
4711 ),
4712 "types": [DType.SHAPE],
4713 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4714 },
4715 "sub_shape": {
4716 "op": Op.SUB_SHAPE,
4717 "operands": (2, 0),
4718 "build_fcn": (
4719 build_shape_op,
4720 TosaTensorGen.tgShape,
4721 TosaTensorValuesGen.tvgAddSub,
4722 TosaArgGen.agNone,
4723 ),
4724 "types": [DType.SHAPE],
4725 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4726 },
4727 "mul_shape": {
4728 "op": Op.MUL_SHAPE,
4729 "operands": (2, 0),
4730 "build_fcn": (
4731 build_shape_op,
4732 TosaTensorGen.tgShape,
4733 TosaTensorValuesGen.tvgMul,
4734 TosaArgGen.agNone,
4735 ),
4736 "types": [DType.SHAPE],
4737 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4738 },
4739 "div_shape": {
4740 "op": Op.DIV_SHAPE,
4741 "operands": (2, 0),
4742 "build_fcn": (
4743 build_shape_op,
4744 TosaTensorGen.tgShape,
4745 TosaTensorValuesGen.tvgIntDiv,
4746 TosaArgGen.agNone,
4747 ),
4748 "types": [DType.SHAPE],
4749 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4750 },
4751 "concat_shape": {
4752 "op": Op.CONCAT_SHAPE,
4753 "operands": (2, 0),
4754 "build_fcn": (
4755 build_concat,
4756 TosaTensorGen.tgConcat,
4757 TosaTensorValuesGen.tvgConcat,
4758 TosaArgGen.agNone,
4759 ),
4760 "types": [DType.SHAPE],
4761 "error_if_validators": (),
4762 },
4763 "const_shape": {
4764 "op": Op.CONST_SHAPE,
4765 "operands": (0, 1),
4766 "build_fcn": (
4767 build_const,
4768 TosaTensorGen.tgBasic,
4769 TosaTensorValuesGen.tvgDefault,
4770 None,
4771 ),
4772 "types": [DType.SHAPE],
4773 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004774 }
4775
Kevin Cheng550ccc52021-03-03 11:21:43 -08004776
Eric Kunzee5e26762020-10-13 16:11:07 -07004777class OutputShaper:
4778 # Methods in this class compute the expected output shape and datatype
4779 # for common classes of operations
def __init__(self):
    """Create a stateless shaper; no instance attributes are initialised."""
4782
    # These methods create the expected output tensor with ser.addOutput()
    # and return whatever addOutput returns
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting element-wise binary op.

    Each output dimension takes ``a``'s size, except that a size-1 dimension
    of ``a`` takes ``b``'s size when no ERROR_IF case is requested
    (``error_name is None``).

    ``error_name`` adjustments:
      * DimensionMismatch - one randomly chosen output dimension grows by 1.
      * WrongOutputType - the dtype is drawn from a fixed list, excluding
        ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[dim] if (size == 1 and error_name is None) else size
        for dim, size in enumerate(a.shape)
    ]

    # Drawn even when unused; making the draw conditional would shift every
    # later random value and so change the generated tests.
    fuzz_dim = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_dim] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004818
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose operands must match exactly.

    Asserts that ``a`` and ``b`` agree in rank, dtype and every dimension;
    the output copies ``a``'s shape (as a new list) and dtype.

    Returns the result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    # Per-dimension equality; lengths already asserted equal above.
    assert all(size_a == size_b for size_a, size_b in zip(a.shape, b.shape))

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004830
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an element-wise unary op.

    The output shape is ``a.shape``. The dtype is ``a.dtype`` unless
    ``error_name`` is WrongOutputType, in which case a different dtype is
    drawn from a fixed list with ``rng``.

    Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    bad_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004849
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a select op driven by ``cond``.

    The output shape follows ``cond``. A size-1 ``cond`` dimension becomes the
    largest of the ``cond``/``a``/``b`` sizes, but only when no ERROR_IF case
    is requested (``error_name is None``).

    ``error_name`` adjustments:
      * DimensionMismatch - one randomly chosen output dimension grows by 1.
      * WrongOutputType - the dtype is drawn from a fixed list, excluding
        ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim, cond_size in enumerate(cond.shape):
        broadcast = cond_size == 1 and error_name is None
        out_shape.append(
            max(cond_size, a.shape[dim], b.shape[dim]) if broadcast else cond_size
        )

    # The index range follows a's rank. The draw happens whether or not it is
    # used, so later random draws stay unchanged.
    fuzz_dim = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_dim] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004883
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor of an element-wise comparison op.

    Broadcasting: a size-1 dimension of ``a`` takes ``b``'s size whenever
    ``b`` has that dimension.

    ``error_name`` adjustments:
      * DimensionMismatch - one randomly chosen output dimension grows by 1.
      * WrongOutputType - a non-BOOL dtype is drawn from a fixed list.

    Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[dim] if size == 1 and dim < len(b.shape) else size
        for dim, size in enumerate(a.shape)
    ]

    # Drawn even when unused so later random draws stay unchanged.
    fuzz_dim = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_dim] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    # A list, not a set: rng.choice picks by position, so this order matters.
    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004917
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction along ``axis``.

    The reduced axis becomes size 1, except for these ``error_name`` values:
      * AxisSmallerZero / AxisLargerRank - the shape is left untouched.
      * ShapeOfAxisNotOne - the axis is not forced to 1; if it already is 1,
        it is replaced by ``rng.integers(2, 10)``.
    WrongOutputType draws a dtype different from ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004946
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of ARGMAX along ``axis``.

    ``axis`` is removed from ``a``'s shape unless the axis itself is the
    error (AxisSmallerZero / AxisLargerRank). Further ``error_name`` cases:
      * ArgmaxOutputRankMismatch - randomly drop the leading dim (only if
        more than one remains) or append a trailing 1.
      * ArgmaxOutputShapeMismatch - grow every dim by ``rng.integers(1, 10)``.
      * WrongOutputType - draw a dtype other than INT32.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [size + rng.integers(1, 10) for size in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004980
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output feature-map tensor of CONV2D.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``
    and the channel count is ``filter.shape[0]``.

    ``error_name`` adjustments:
      * ConvOutputShapeMismatch - grow H and/or W by a random multiple of the
        matching stride.
      * WrongInputType - output dtype INT32 instead of ``accum_dtype``.
      * WrongOutputType - draw from ``gtu.usableDTypes`` minus the expected
        dtype(s).

    Returns the result of ``ser.addOutput``.
    """

    def spatial(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    out_hw = [
        spatial(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i + 1],
            dilations[i],
            strides[i],
        )
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # change: 1 -> H only, 2 -> W only, 3 -> both. Growing by whole
        # strides avoids the non-integer output-size error case.
        change = rng.choice(choices)
        for i in range(2):
            if change in (i + 1, 3):
                out_hw[i] = out_hw[i] + rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], out_hw[0], out_hw[1], filter.shape[0]]

    # With a wrong input type, pick some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16).
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005032
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output feature-map tensor of CONV3D.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``
    and the channel count is ``filter.shape[0]``.

    ``error_name`` adjustments:
      * ConvOutputShapeMismatch - grow D, H or W (or all three) by a random
        multiple of the matching stride.
      * WrongInputType - output dtype INT32 instead of ``accum_dtype``.
      * WrongOutputType - draw from ``gtu.usableDTypes`` minus the expected
        dtype(s).

    Returns the result of ``ser.addOutput``.
    """

    def spatial(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    out_dhw = [
        spatial(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i + 1],
            dilations[i],
            strides[i],
        )
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # change: 1 -> D, 2 -> H, 3 -> W, 4 -> all. Growing by whole strides
        # avoids the non-integer output-size error case.
        change = rng.choice(choices)
        for i in range(3):
            if change in (i + 1, 4):
                out_dhw[i] = out_dhw[i] + rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *out_dhw, filter.shape[0]]

    # With a wrong input type, pick some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16).
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5094
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output feature-map tensor of DEPTHWISE_CONV2D.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``
    with kernel sizes ``filter.shape[0]`` (H) and ``filter.shape[1]`` (W).

    ``error_name`` adjustments:
      * ConvOutputShapeMismatch - grow H and/or W by a random multiple of the
        matching stride.
      * WrongInputType - output dtype INT32 instead of ``accum_dtype``.
      * WrongOutputType - draw from ``gtu.usableDTypes`` minus the expected
        dtype(s).

    Returns the result of ``ser.addOutput``.
    """

    def spatial(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    # The HWCM filter keeps its kernel sizes in dims 0 and 1.
    out_hw = [
        spatial(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i],
            dilations[i],
            strides[i],
        )
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # change: 1 -> H only, 2 -> W only, 3 -> both. Growing by whole
        # strides avoids the non-integer output-size error case.
        change = rng.choice(choices)
        for i in range(2):
            if change in (i + 1, 3):
                out_hw[i] = out_hw[i] + rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], out_hw[0], out_hw[1], filter.shape[2] * filter.shape[3]]

    # With a wrong input type, pick some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16).
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005145
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2-D pooling op (NHWC in, NHWC out).

    Spatial sizes are ``(in + pad_before + pad_after - kernel) // stride + 1``.
    A non-positive stride or a negative pad yields a 1x1 output, since such a
    test is invalid anyway.

    ``error_name`` adjustments:
      * PoolingOutputShapeMismatch - grow H and/or W by a random multiple of
        the matching stride.
      * WrongOutputType - a dtype different from ``ifm.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        # change: 1 -> H only, 2 -> W only, 3 -> both. Growing by whole
        # strides avoids the non-integer output-size error case.
        change = rng.choice(choices)
        if change in (1, 3):
            out_h = out_h + rng.choice(choices) * stride[0]
        if change in (2, 3):
            out_w = out_w + rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005183
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor of FULLY_CONNECTED.

    Layouts: input (N, IC), filter (OC, IC), output (N, OC). The dtype is
    ``accum_dtype`` as given; ``rng`` and ``error_name`` are not used here.

    Returns the result of ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and invalidated there for ERROR_IF).
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005196
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor of MATMUL.

    Layouts: a (N, H, C), b (N, C, W), output (N, H, W).

    Output dtype:
      * WrongOutputType - drawn from a per-input-dtype list of incorrect
        types; for any other input dtype, drawn from ``gtu.usableDTypes``
        excluding ``accum_dtype``.
      * WrongInputType - INT32, a potentially correct output for a bad input.
      * otherwise ``accum_dtype`` (validated in arg_gen).

    Returns the result of ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch any other input dtype left incorrect_types
            # unbound (UnboundLocalError). Mirror the convolution shapers:
            # anything usable except the expected accumulator dtype.
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005244
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor of CONCAT along ``axis``.

    The output starts as a copy of the first input's shape. The ``axis``
    sizes of the remaining inputs are added to it, except for
    ConcatInputRankMismatch, AxisLargerRank and AxisSmallerZero, where no sum
    can be formed. ConcatShapeSumMismatch then adds ``rng.integers(5, 10)``
    to the axis, and WrongOutputType picks a dtype different from the first
    input's.

    Returns the result of ``ser.addOutput``.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    if error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ):
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005280
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of PAD.

    Each dimension grows by ``padding[dim][0] + padding[dim][1]``.

    ``error_name`` adjustments:
      * PadOutputShapeMismatch - one random dimension shrinks by 1 or 2.
      * RankMismatch - the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
      * WrongOutputType - a dtype different from ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    for dim in range(len(out_shape)):
        before, after = padding[dim][0], padding[dim][1]
        out_shape[dim] = before + after + out_shape[dim]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005311
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the single-element SHAPE output tensor of DIM.

    The output shape is always ``[1]`` and the dtype is ``DType.SHAPE``,
    unless ``error_name`` is WrongOutputType. In that case a dtype is drawn
    from a fixed list that does not contain SHAPE. ``a`` and ``axis`` are not
    used here.

    Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    # SHAPE is absent from the candidates, so nothing needs excluding.
    return ser.addOutput([1], rng.choice(list(set(candidates))))
5332
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of RESHAPE with target ``shape``.

    ``error_name`` adjustments:
      * TensorSizeInputOutputMismatch - every dim grows by
        ``rng.integers(1, 10)``.
      * WrongOutputType - a dtype different from ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(out_shape)):
            out_shape[dim] += rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005357
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor of SLICE; its shape is ``size``.

    ``error_name`` adjustments:
      * WrongOutputType - a dtype different from ``input.dtype``.
      * SizeOutputShapeMismatch - every dim is nudged: +1/+2 when it is <= 2,
        otherwise by one of -2, -1, +1, +2.
      * InputSizeStartLengthMismatch - the input's shape is used instead.
      * RankMismatch - the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
    ``start`` is not used here.

    Returns the result of ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim in range(len(out_shape)):
            # Small dims only grow, so the result stays positive.
            deltas = [1, 2] if out_shape[dim] <= 2 else [-2, -1, 1, 2]
            out_shape[dim] = out_shape[dim] + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005391
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor of TILE: dim ``i`` is ``a.shape[i] * multiples[i]``.

    ``error_name`` adjustments:
      * RankMismatch - the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
      * WrongOutputType - a dtype different from ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for dim in range(len(out_shape)):
        out_shape[dim] = a.shape[dim] * multiples[dim]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005420
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of TRANSPOSE: dim ``i`` is ``a.shape[perms[i]]``.

    For IndexOutsideBounds / IndexUsedTwice the permutation is not applied
    and the input shape is kept. Further ``error_name`` cases:
      * TensorSizeInputOutputMismatch - every dim grows by
        ``rng.integers(1, 10)``.
      * RankMismatch - the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
      * WrongOutputType - a dtype different from ``a.dtype``.

    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    permute = error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if permute:
        for dim in range(len(out_shape)):
            out_shape[dim] = a.shape[perms[dim]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(out_shape)):
            out_shape[dim] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005453
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The dtype is
    ``values.dtype`` unless the ERROR_IF is WrongOutputType, in which case
    a different dtype is picked at random.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    output_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        every_dtype = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(every_dtype) - {values.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005479
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    The output has the same shape as ``values_in``. The asserts below check
    the N/W/C relationships between ``values_in``, ``indices`` and
    ``input``; the ``values_in`` rank check is skipped for WrongRank.

    The dtype is ``values_in.dtype`` unless the ERROR_IF is WrongOutputType,
    in which case a different dtype is picked at random.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy the shape so the output tensor does not share (alias) the
    # values_in tensor's shape list; the other shapers copy as well.
    output_shape = list(values_in.shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005508
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The output keeps the input shape. Its dtype is INT32 for an INT16
    input and INT8 otherwise. For WrongOutputType, a different dtype is
    picked at random from the remaining candidates.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice([dt for dt in candidates if dt != output_dtype])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005528
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE.

    ``scale`` is read as ``[y_n, y_d, x_n, x_d]``; ``offset`` and ``border``
    as ``[y, x]``. The output height/width are::

        OH = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
        OW = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    The output is ``[input.shape[0], OH, OW, input.shape[3]]`` with dtype
    ``output_dtype``. ``mode`` and ``input_dtype`` are not used here.

    For ERROR_IF tests, OH/OW are clamped and the dims may be deliberately
    corrupted according to ``error_name``.

    Returns:
        The tensor returned by ``serializer.addOutput``.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to at least 1 so the divisions below are never by zero.
        y_n, y_d, x_n, x_d = (max(v, 1) for v in (y_n, y_d, x_n, x_d))

    oh = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    ow = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # scale/offset/border may have been altered for the ERROR_IF, so keep
        # the output dims valid: at least 1, and below the maximum unless
        # exceeding it is the error under test.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])
        # Step by a whole denominator so the non-integer error case is not hit.
        if change in (1, 3):
            if oh + y_d >= gtu.MAX_RESIZE_DIMENSION:
                oh -= y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += y_d
        if change in (2, 3):
            if ow + x_d >= gtu.MAX_RESIZE_DIMENSION:
                ow -= x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += x_d

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses input.shape[0], presumably because
        # a wrong-rank input may have no shape[3] -- confirm.
        output_dims = [batch, oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005607
5608 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005609 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005610 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005611
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    The shape is ``output_shape``. The dtype is ``accum_dtype``, or INT32
    when the ERROR_IF is WrongInputType. For WrongOutputType, a dtype is
    drawn from ``gtu.usableDTypes`` excluding the expected one(s).

    NOTE: for ConvOutputShapeMismatch, dims 1 and/or 2 of ``output_shape``
    are modified in place, so the caller's list sees the change too.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in (1, 3):
            output_shape[1] += rng.choice(steps)
        if which in (2, 3):
            output_shape[2] += rng.choice(steps)

    # With an invalid input type, fall back to a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both are valid).
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005637
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D.

    Both outputs get a copy of ``ifm1``'s shape and ``ifm1``'s dtype. The
    same shape list and dtype are passed to both ``addOutput`` calls.

    ERROR_IF handling:
      * WrongOutputType: dtype drawn from ``gtu.usableDTypes`` excluding FP32.
      * BatchMismatch: dim 0 grows by a random 1..9.
      * FFTOutputShapeMismatch: dim 1 or 2 grows by a random 1..9.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        output_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5668
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D.

    The outputs keep every input dim except the last, which becomes
    ``W // 2 + 1`` where ``W`` is the input's last dim. The dtype is
    ``value.dtype``. The same shape list and dtype are passed to both
    ``addOutput`` calls.

    ERROR_IF handling:
      * WrongOutputType: dtype drawn from ``gtu.usableDTypes`` excluding FP32.
      * BatchMismatch: dim 0 grows by a random 1..9.
      * FFTOutputShapeMismatch: dim 1 or 2 grows by a random 1..9.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        output_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005693
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for the shape addition of ``a`` and ``b``.

    The output copies ``a``'s shape and has dtype ``DType.SHAPE``.

    ERROR_IF handling:
      * DimensionMismatch: one randomly chosen dim is increased by 1.
      * WrongOutputType: dtype chosen from
        ``gtu.get_wrong_output_type(a.dtype)``.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # NOTE: the index is drawn for every error_name (even when unused), so the
    # RNG advances identically regardless of the ERROR_IF being generated.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(shape, out_dtype)