blob: 39b064d509ae4c74f2ca277dc5e1540bcc0d00d5 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data file.
# The copyright year is the local date's year at module import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further 8K level limits, named after the specification's level parameters
# (MAX_SCALE, MAX_KERNEL, MAX_STRIDE).
# NOTE(review): assumed to mirror the spec's 8K level table - confirm there when updating.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
# NOTE(review): presumably TEST_SETS is the number of dot product test sets
# generated and MIN a lower bound used when sizing them - confirm against the
# argument generator that reads these values.
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator from the parsed command line arguments.

    Seeds self.rng from args.random_seed, then calls createDynamicOpLists
    and initOpListDefaults. It creates the quantization helper, the
    desc.json schema validator and the data generator library. Finally it
    resolves args.tensor_fp_value_range into a (low, high) tuple for each
    floating point dtype.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set (truthy), makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Every desc.json is validated against the schema before being written
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library is needed for compliance set up even when
    # data generation itself is deferred (lazy_data_gen)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        # Turn "max"/"-max" into +/-maxFP, clamp any other non-zero bound
        # into [-maxFP, maxFP], and return the bounds sorted as a tuple
        def _resolve(bound):
            if bound == "max":
                return maxFP
            if bound == "-max":
                return -maxFP
            if bound < 0:
                return max(bound, -maxFP)
            if bound > 0:
                return min(bound, maxFP)
            return bound

        return tuple(sorted(_resolve(bound) for bound in rangeFP))

    self.random_float_range = {
        fp_dtype: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype],
        )
        for fp_dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create the directory for one test and a fresh serializer for it.

    self.testPath is set to opName/testPath, relative to self.basePath;
    serialize() later uses it. The constant handling mode depends on args:
    - INPUTS when lazy_data_gen is set, so constants become input files.
    - EMBED_DUMP when dump_consts is set.
    - EMBED otherwise, so const data is embedded in the flatbuffer.
    """
    self.testPath = os.path.join(opName, testPath)
    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
def serialize(self, testName, metaData=None):
    """Write out one test: flatbuffer, desc.json and optional C++ meta data.

    All files go into basePath/testPath, as set by createSerializer:
    - ``<testName>.tosa``: the serialized flatbuffer.
    - ``desc.json``: the serializer's JSON description, with metaData (when
      truthy) stored under "meta". It is validated against the schema
      before anything else is written, and dumped with indent=1.
    - ``<testName>_meta_data_gen.cpp``: metaData["data_gen"] as a C++ raw
      string, written only when args.lazy_data_gen is set.
    - ``<testName>_meta_compliance.cpp``: metaData["compliance"] as a C++
      raw string, written whenever that key is present.
    """
    out_dir = Path(self.basePath) / self.testPath

    def _write_cpp_meta(file_name, comment_line, declaration, payload):
        # Emit payload as JSON inside a C++ raw string literal, preceded by
        # the auto-generated file header and a descriptive comment
        with (out_dir / file_name).open("w") as cpp_file:
            cpp_file.write(TOSA_AUTOGENERATED_HEADER)
            cpp_file.write(comment_line)
            cpp_file.write(declaration)
            json.dump(payload, cpp_file)
            cpp_file.write(')";\n\n')

    with (out_dir / f"{testName}.tosa").open("wb") as flatbuffer_file:
        flatbuffer_file.write(self.ser.serialize())

    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Never emit a desc.json that does not satisfy the schema
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            _write_cpp_meta(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{out_dir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            _write_cpp_meta(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{out_dir.stem} = R"(',
                metaData["compliance"],
            )

    with (out_dir / "desc.json").open("w") as desc_file:
        json.dump(desc, desc_file, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace self.rng with a generator seeded by *seed*.

    When seed is None, random_seed + 1 is used instead.
    """
    self.rng = np.random.default_rng(
        self.random_seed + 1 if seed is None else seed
    )
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) boundaries used for random values of *dtype*.

    For non floating point dtypes, high is excluded (low <= v < high) unless
    high_inclusive is True, in which case (low, high - 1) is returned.
    FP32/FP16/BF16 return the precomputed self.random_float_range entry
    unchanged, whatever high_inclusive says. SHAPE uses the first two
    entries of args.tensor_shape_range.

    Raises Exception for a dtype with no known range.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        fixed_ranges = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )
        bounds = next(
            (span for candidate, span in fixed_ranges if dtype == candidate), None
        )
        if bounds is None:
            raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= v <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= v < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of *shape* for the TOSA *dtype*.

    Values come from data_range (low, high) when given, otherwise from
    getDTypeRange(dtype). The range is always resolved first, even for BOOL,
    which then ignores it. Numpy storage types:
    - BOOL -> bool_
    - INT8 -> int8
    - UINT8 -> uint8
    - INT48/SHAPE -> int64
    - FP16 -> float16
    - FP32 -> float32
    - BF16 -> float32 whose values went through gtu.vect_f32_to_bf16
    - any other integer dtype -> int32
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        as_f32 = np.float32(samples)
        if dtype != DType.BF16:
            return as_f32
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(as_f32))

    if dtype == DType.INT8:
        cast = np.int8
    elif dtype == DType.UINT8:
        cast = np.uint8
    elif dtype in (DType.INT48, DType.SHAPE):
        cast = np.int64
    else:
        # All other integer types
        cast = np.int32
    return cast(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair and return them in order.

    With args.lazy_data_gen set, no data is attached (None is passed).
    Otherwise each placeholder gets a getRandTensor array.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
228
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair and return them in order.

    With args.lazy_data_gen set, no data is attached (None is passed).
    Otherwise each constant gets a getRandTensor array.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
241
242 def makeShape(self, rank):
243 if self.targetted_shape:
244 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800245 return np.int32(
246 self.rng.integers(
247 low=self.args.tensor_shape_range[0],
248 high=self.args.tensor_shape_range[1],
249 size=rank,
250 )
251 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
253 def setTargetShape(self, shape):
254 self.targetted_shape = shape
255
256 def randInt(self, low=0, high=256):
257 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
258
259 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 low, high = self.getDTypeRange(dtype)
261
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100264 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100265 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100266 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100267 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
268 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 elif dtype == DType.BOOL:
270 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000271 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 # Special size
273 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
275 return np.int32(self.rng.integers(low, high, size=1))[0]
276
277 def shapeStr(self, shape):
278
279 sStr = []
280 # Convert to strings
281 for i in shape:
282 sStr.append(str(i))
283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeStr(self, dtype):
    """Return the short string name of a dtype, or of a list/tuple of dtypes.

    For a list/tuple (at least two entries), every entry is converted, but
    only the first two names are joined with "x". A third entry is the
    accumulator and is left out of the name.

    Raises Exception for a single dtype missing from gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(element) for element in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Reads the "width" entry of gtu.DTYPE_ATTRIBUTES[dtype]. Raises
    Exception when the dtype has no entry.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
def constrictBatchSize(self, shape):
    """Limit the batch size (shape[0]) to args.max_batch_size, in place.

    Nothing changes when max_batch_size is unset/zero or when explicit
    target shapes were requested. Returns the same *shape* object.
    """
    limit = self.args.max_batch_size
    if limit and not self.args.target_shapes:
        shape[0] = min(shape[0], limit)
    return shape
312
def makeDimension(self):
    """Return a random dimension size via randInt over args.tensor_shape_range[0:2]."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
317
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data for an op's expected output tensor.

    Returns None in three cases:
    - error tests (truthy errorName);
    - output dtypes rejected by gtu.dtypeIsSupportedByCompliance;
    - the listed dot product ops when their input dtype is rejected.

    Otherwise returns a dict, in this key order:
    - "mode": a gtu.ComplianceMode name;
    - "data_type": the output dtype's JSON name;
    - for some modes, one extra "*_info" entry.

    The mode is the first match of:
    - DOT_PRODUCT: argsDict["dg_type"] is DataGenType.DOT_PRODUCT;
    - FP_SPECIAL: argsDict["dg_type"] is DataGenType.OP_SPECIAL;
    - ULP: op["compliance"] contains "ulp";
    - REDUCE_PRODUCT: op["op"] is Op.REDUCE_PRODUCT;
    - ABS_ERROR: op["op"] is EXP, POW, TANH or SIGMOID;
    - EXACT: anything else.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    dot_product_ops = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in dot_product_ops
    ):
        return None

    # Data type is needed for all FP runs, as refmodel precise mode produces FP64
    data_type = gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"]

    dg_type = argsDict["dg_type"]
    mode_info = {}
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        mode_info["dot_product_info"] = {
            "s": argsDict["s"],
            # "ksb" takes precedence over "ks" when present
            "ks": int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        mode_info["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        mode_info["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            mode_info["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    return {
        "mode": gtu.ComplianceMode(mode).name,
        "data_type": data_type,
        **mode_info,
    }
373
374 # Build Op functions
375 # Create the output tensor (calling OutputShaper as needed)
376 # Do final tweaks to attributes (if necessary for errorIf)
377 # Add Op into graph
378 # Return resulting tensor information or BuildInfo
379
class BuildInfo:
    """Result of an op build: the result tensor plus its compliance meta data.

    complianceDict holds whatever tensorComplianceMetaData produced for the
    result tensor, which may be None.
    """

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700386
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    NEGATE also gets a NegateAttribute built from qinfo[0] and qinfo[1].
    For ErrorIf.WrongOutputType tests with a non INT8/UINT8 output, qinfo
    is replaced by the zero points of the input and output dtypes.

    Returns None when TosaErrorValidator.evValidateErrorIfs reports failure.
    Otherwise returns a TosaTestGen.BuildInfo holding the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    result_tensor = OutputShaper.unaryOp(
        self.ser, self.rng, input_tensor, error_name
    )

    assert not isinstance(op, int)

    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        # Ensure new output type has correct qinfo
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    pCount, cCount = op["operands"]
    # Error tests may deliberately invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input_tensor.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=input_tensor.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700439
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a broadcasting two-input operator to the graph.

    validator_fcns defaults to None, like the other op builders in this
    class. Callers that pass it explicitly behave exactly as before.
    qinfo is accepted for a uniform builder signature but is not used here.

    Returns None when TosaErrorValidator.evValidateErrorIfs reports failure.
    Otherwise returns a TosaTestGen.BuildInfo holding the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700481
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting two-input operator and return its result tensor.

    No ERROR_IF validation is done here: validator_fcns and error_name are
    accepted but unused.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [output.name])
    return output
486
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add a broadcasting shift operator with an ArithmeticRightShiftAttribute.

    *round* is passed unchanged to ArithmeticRightShiftAttribute. The name
    shadows the builtin within this function only. Returns None when
    TosaErrorValidator.evValidateErrorIfs reports failure, otherwise the
    result tensor.
    """
    shifted = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    pCount, cCount = op["operands"]
    # Error tests may deliberately invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [shifted.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=shifted.dtype,
        result_tensors=[shifted],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return shifted
524
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator carrying a MulAttribute(args_dict["shift"]).

    For non floating point inputs, the result dtype is forced to INT32.
    For ErrorIf.WrongOutputType tests, INT8, INT16 or INT48 is then picked
    with self.rng instead.

    Returns None when TosaErrorValidator.evValidateErrorIfs reports failure.
    Otherwise returns a TosaTestGen.BuildInfo holding the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    product = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # Special for multiply: Force the result to INT32 for INT types
        product.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        product.setDtype(self.rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    pCount, cCount = op["operands"]
    # Error tests may deliberately invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [product.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=product.dtype,
        result_tensors=[product],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, product, error_name
    )
    return TosaTestGen.BuildInfo(product, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700580
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup operator whose attribute is TableAttribute(table).

    The attribute is built before validation, as in the rest of this
    builder's history. Returns None when
    TosaErrorValidator.evValidateErrorIfs reports failure, otherwise the
    result tensor.
    """
    looked_up = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    pCount, cCount = op["operands"]
    # Error tests may deliberately invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [looked_up.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=looked_up.dtype,
        result_tensors=[looked_up],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return looked_up
614
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a three-input select operator (condition, then two value tensors).

    qinfo is accepted for a uniform builder signature but is not used here.

    Returns None when TosaErrorValidator.evValidateErrorIfs reports failure.
    Otherwise returns a TosaTestGen.BuildInfo holding the result tensor and
    its compliance meta data.
    """
    assert len(inputs) == 3
    cond, x, y = inputs

    selected = OutputShaper.selectOp(self.ser, self.rng, cond, x, y, error_name)

    pCount, cCount = op["operands"]
    # Error tests may deliberately invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, x.name, y.name], [selected.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=x,
        input3=y,
        input_shape=x.shape,
        input_dtype=x.dtype,
        output_dtype=selected.dtype,
        result_tensors=[selected],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, x.dtype, args_dict, selected, error_name
    )
    return TosaTestGen.BuildInfo(selected, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700662
Jeremy Johnsona0150012023-11-15 15:52:06 +0000663 def build_comparison(
664 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
665 ):
666 assert len(inputs) == 2
667 a, b = inputs
668
669 result_tensor = OutputShaper.binaryComparisonOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000670 self.ser, self.rng, a, b, error_name
671 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100672
673 # Invalidate Input/Output list for error if checks.
674 input_list = [a.name, b.name]
Jeremy Johnsona0150012023-11-15 15:52:06 +0000675 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100676 pCount, cCount = op["operands"]
677 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000678 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
679 self, error_name, input_list, output_list
680 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100681
Les Bell729b0352021-11-24 10:28:21 +0000682 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100683 self.ser,
684 validator_fcns,
685 error_name,
686 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000687 input1=a,
688 input2=b,
689 input_shape=a.shape,
690 input_dtype=a.dtype,
Jeremy Johnsona0150012023-11-15 15:52:06 +0000691 output_shape=result_tensor.shape,
692 output_dtype=result_tensor.dtype,
693 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100694 input_list=input_list,
695 output_list=output_list,
696 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000697 ):
698 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100699
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000700 self.ser.addOperator(
701 op["op"],
702 input_list,
703 output_list,
704 )
Jeremy Johnsona0150012023-11-15 15:52:06 +0000705
706 compliance = self.tensorComplianceMetaData(
707 op, a.dtype, args_dict, result_tensor, error_name
708 )
709 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700710
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000711 def build_argmax(
712 self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
713 ):
714 assert len(inputs) == 1
715 a = inputs[0]
716 axis = args_dict["axis"]
717 result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100718
719 # Invalidate Input/Output list for error if checks.
720 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000721 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100722 pCount, cCount = op["operands"]
723 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000724 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
725 self, error_name, input_list, output_list
726 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100727
Les Bell729b0352021-11-24 10:28:21 +0000728 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100729 self.ser,
730 validator_fcns,
731 error_name,
732 op=op,
733 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000734 input_shape=a.shape,
735 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000736 output_shape=result_tensor.shape,
737 output_dtype=result_tensor.dtype,
738 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100739 input_list=input_list,
740 output_list=output_list,
741 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000742 ):
743 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700744
745 attr = ts.TosaSerializerAttribute()
746 attr.AxisAttribute(axis)
747
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000748 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000749
750 compliance = self.tensorComplianceMetaData(
751 op, inputs[0].dtype, args_dict, result_tensor, error_name
752 )
753 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700754
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000755 def build_pool2d(
756 self,
757 op,
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100758 inputs,
759 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000760 validator_fcns=None,
761 error_name=None,
762 qinfo=None,
763 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100764 assert len(inputs) == 1
765 input = inputs[0]
766 # max_pool has no accum_dtype
767 accum_dtype = (
768 args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
769 )
770 stride = args_dict["stride"]
771 pad = args_dict["pad"]
772 kernel = args_dict["kernel"]
773
Jeremy Johnson0601f802023-11-08 16:28:09 +0000774 result_tensor = OutputShaper.pool2dOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000775 self.ser, self.rng, input, kernel, stride, pad, error_name
776 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100777
778 # Ensure new output type has correct qinfo
779 if error_name == ErrorIf.WrongInputType:
780 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000781 qinfo = [
782 TosaQuantGen.getZeroPoint(self, input.dtype),
Jeremy Johnson0601f802023-11-08 16:28:09 +0000783 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000784 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100785
786 # Invalidate Input/Output list for error if checks.
787 input_list = [input.name]
Jeremy Johnson0601f802023-11-08 16:28:09 +0000788 output_list = [result_tensor.name]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100789 pCount, cCount = op["operands"]
790 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000791 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
792 self, error_name, input_list, output_list
793 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100794
Les Bell729b0352021-11-24 10:28:21 +0000795 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100796 self.ser,
797 validator_fcns,
798 error_name,
799 op=op,
800 input_shape=input.shape,
801 input_dtype=input.dtype,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000802 output_shape=result_tensor.shape,
803 output_dtype=result_tensor.dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100804 kernel=kernel,
805 stride=stride,
806 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000807 qinfo=qinfo,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000808 result_tensors=[result_tensor],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100809 input_list=input_list,
810 output_list=output_list,
811 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000812 ):
813 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700814
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000815 if qinfo is None:
816 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700817
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000818 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100819 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000820
821 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700822
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100823 compliance = self.tensorComplianceMetaData(
824 op, inputs[0].dtype, args_dict, result_tensor, error_name
825 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100826
827 return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100828
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000829 def build_conv2d(
830 self,
831 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100832 inputs,
833 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000834 validator_fcns=None,
835 error_name=None,
836 qinfo=None,
837 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100838 assert len(inputs) == 3
839 ifm, filter, bias = inputs
840 accum_dtype = args_dict["acc_type"]
841 strides = args_dict["stride"]
842 padding = args_dict["pad"]
843 dilations = args_dict["dilation"]
844
Kevin Cheng550ccc52021-03-03 11:21:43 -0800845 assert len(padding) == 4
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100846 result_tensor = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100847 self.ser,
848 self.rng,
849 ifm,
850 filter,
851 accum_dtype,
852 strides,
853 padding,
854 dilations,
855 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000856 )
857
858 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000859 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
860 DType.INT8,
861 DType.UINT8,
862 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000863 qinfo = [
864 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100865 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000866 ]
Les Bell0e027d42021-11-09 14:42:14 +0000867
868 # Invalidate Input/Output list for error_if checks.
869 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100870 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +0000871 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000872 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
873 self, error_name, input_list, output_list
874 )
Les Bell0e027d42021-11-09 14:42:14 +0000875
Les Bell729b0352021-11-24 10:28:21 +0000876 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000877 self.ser,
878 validator_fcns,
879 error_name,
880 op=op,
881 input_dtype=ifm.dtype,
882 weight_dtype=filter.dtype,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100883 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +0000884 qinfo=qinfo,
885 input_list=input_list,
886 num_operands=num_operands,
887 output_list=output_list,
888 pad=padding,
889 stride=strides,
890 dilation=dilations,
891 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100892 weight_shape=filter.shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100893 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +0000894 ):
895 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700896
Tai Lyd3797f02023-11-15 23:06:19 +0000897 # TODO - Test local_bound, for now set local bound attribute to False
898 local_bound = False
899
Eric Kunzee5e26762020-10-13 16:11:07 -0700900 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +0000901 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -0700902
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000903 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100904
905 compliance = self.tensorComplianceMetaData(
906 op, ifm.dtype, args_dict, result_tensor, error_name
907 )
908
909 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700910
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000911 def build_conv3d(
912 self,
913 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100914 inputs,
915 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000916 validator_fcns=None,
917 error_name=None,
918 qinfo=None,
919 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100920 assert len(inputs) == 3
921 ifm, filter, bias = inputs
922 accum_dtype = args_dict["acc_type"]
923 strides = args_dict["stride"]
924 padding = args_dict["pad"]
925 dilations = args_dict["dilation"]
926
Kevin Cheng1533b852021-09-01 12:51:58 -0700927 assert len(padding) == 6
928 result_tens = OutputShaper.conv3dOp(
James Ward8b390432022-08-12 20:48:56 +0100929 self.ser,
930 self.rng,
931 ifm,
932 filter,
933 accum_dtype,
934 strides,
935 padding,
936 dilations,
937 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000938 )
939
940 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000941 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
942 DType.INT8,
943 DType.UINT8,
944 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000945 qinfo = [
946 TosaQuantGen.getZeroPoint(self, ifm.dtype),
947 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
948 ]
Les Bell0e027d42021-11-09 14:42:14 +0000949
950 # Invalidate Input/Output list for error_if checks.
951 input_list = [ifm.name, filter.name, bias.name]
952 output_list = [result_tens.name]
953 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000954 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
955 self, error_name, input_list, output_list
956 )
Les Bell0e027d42021-11-09 14:42:14 +0000957
Les Bell729b0352021-11-24 10:28:21 +0000958 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000959 self.ser,
960 validator_fcns,
961 error_name,
962 op=op,
963 input_dtype=ifm.dtype,
964 weight_dtype=filter.dtype,
965 output_dtype=result_tens.dtype,
966 qinfo=qinfo,
967 input_list=input_list,
968 num_operands=num_operands,
969 output_list=output_list,
970 pad=padding,
971 stride=strides,
972 dilation=dilations,
973 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100974 weight_shape=filter.shape,
975 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000976 ):
977 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700978
Tai Lyd3797f02023-11-15 23:06:19 +0000979 # TODO - Test local_bound, for now set local bound attribute to False
980 local_bound = False
981
Kevin Cheng1533b852021-09-01 12:51:58 -0700982 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +0000983 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Kevin Cheng1533b852021-09-01 12:51:58 -0700984
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000985 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700986 return result_tens
987
Kevin Cheng550ccc52021-03-03 11:21:43 -0800988 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000989 self,
990 op,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000991 inputs,
992 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000993 validator_fcns=None,
994 error_name=None,
995 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800996 ):
Jeremy Johnson95a67102024-01-10 14:16:39 +0000997 assert len(inputs) == 3
998 ifm, filter, bias = inputs
999 accum_dtype = args_dict["acc_type"]
1000 strides = args_dict["stride"]
1001 out_pad = args_dict["pad"]
1002 output_shape = args_dict["out_shape"]
1003
TatWai Chong24594f52022-06-08 00:48:04 -07001004 assert len(out_pad) == 4
Jeremy Johnson95a67102024-01-10 14:16:39 +00001005 result_tensor = OutputShaper.transposeConv2DOp(
James Ward8b390432022-08-12 20:48:56 +01001006 self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001007 )
Les Bell0e027d42021-11-09 14:42:14 +00001008
1009 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001010 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
1011 DType.INT8,
1012 DType.UINT8,
1013 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001014 qinfo = [
1015 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnson95a67102024-01-10 14:16:39 +00001016 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001017 ]
Les Bell0e027d42021-11-09 14:42:14 +00001018
1019 # Invalidate Input/Output list for error_if checks.
1020 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnson95a67102024-01-10 14:16:39 +00001021 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +00001022 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001023 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1024 self, error_name, input_list, output_list
1025 )
Les Bell0e027d42021-11-09 14:42:14 +00001026
Les Bell729b0352021-11-24 10:28:21 +00001027 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00001028 self.ser,
1029 validator_fcns,
1030 error_name,
1031 op=op,
1032 input_dtype=ifm.dtype,
1033 weight_dtype=filter.dtype,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001034 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +00001035 qinfo=qinfo,
1036 input_list=input_list,
1037 num_operands=num_operands,
1038 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -07001039 pad=out_pad,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001040 stride=strides,
Les Bell0e027d42021-11-09 14:42:14 +00001041 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001042 weight_shape=filter.shape,
Jeremy Johnson95a67102024-01-10 14:16:39 +00001043 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +00001044 ):
1045 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001046
Tai Lyd3797f02023-11-15 23:06:19 +00001047 # TODO - Test local_bound, for now set local bound attribute to False
1048 local_bound = False
1049
Eric Kunzee5e26762020-10-13 16:11:07 -07001050 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001051 attr.TransposeConvAttribute(
Jeremy Johnson95a67102024-01-10 14:16:39 +00001052 out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
Tai Lyd3797f02023-11-15 23:06:19 +00001053 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001054
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001055 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson95a67102024-01-10 14:16:39 +00001056
1057 compliance = self.tensorComplianceMetaData(
1058 op, ifm.dtype, args_dict, result_tensor, error_name
1059 )
1060
1061 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001062
Kevin Cheng550ccc52021-03-03 11:21:43 -08001063 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001064 self,
1065 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001066 inputs,
1067 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001068 validator_fcns=None,
1069 error_name=None,
1070 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001071 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001072 assert len(inputs) == 3
1073 ifm, filter, bias = inputs
1074 accum_dtype = args_dict["acc_type"]
1075 strides = args_dict["stride"]
1076 padding = args_dict["pad"]
1077 dilations = args_dict["dilation"]
1078
Jeremy Johnson4f931302024-01-04 17:05:24 +00001079 result_tensor = OutputShaper.depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01001080 self.ser,
1081 self.rng,
1082 ifm,
1083 filter,
1084 accum_dtype,
1085 strides,
1086 padding,
1087 dilations,
1088 error_name,
Les Bell0e027d42021-11-09 14:42:14 +00001089 )
1090
1091 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001092 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
1093 DType.INT8,
1094 DType.UINT8,
1095 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001096 qinfo = [
1097 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnson4f931302024-01-04 17:05:24 +00001098 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001099 ]
Les Bell0e027d42021-11-09 14:42:14 +00001100
1101 # Invalidate Input/Output list for error_if checks.
1102 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnson4f931302024-01-04 17:05:24 +00001103 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +00001104 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001105 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1106 self, error_name, input_list, output_list
1107 )
Les Bell0e027d42021-11-09 14:42:14 +00001108
Les Bell729b0352021-11-24 10:28:21 +00001109 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00001110 self.ser,
1111 validator_fcns,
1112 error_name,
1113 op=op,
1114 input_dtype=ifm.dtype,
1115 weight_dtype=filter.dtype,
Jeremy Johnson4f931302024-01-04 17:05:24 +00001116 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +00001117 qinfo=qinfo,
1118 input_list=input_list,
1119 num_operands=num_operands,
1120 output_list=output_list,
1121 pad=padding,
1122 stride=strides,
1123 dilation=dilations,
1124 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001125 weight_shape=filter.shape,
Jeremy Johnson4f931302024-01-04 17:05:24 +00001126 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +00001127 ):
1128 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001129
Tai Lyd3797f02023-11-15 23:06:19 +00001130 # TODO - Test local_bound, for now set local bound attribute to False
1131 local_bound = False
1132
Eric Kunzee5e26762020-10-13 16:11:07 -07001133 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001134 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001136 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson4f931302024-01-04 17:05:24 +00001137
1138 compliance = self.tensorComplianceMetaData(
1139 op, ifm.dtype, args_dict, result_tensor, error_name
1140 )
1141
1142 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001143
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001144 def build_fully_connected(
James Ward8b390432022-08-12 20:48:56 +01001145 self,
1146 op,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001147 inputs,
1148 args_dict,
James Ward8b390432022-08-12 20:48:56 +01001149 validator_fcns=None,
1150 error_name=None,
1151 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001152 ):
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001153 assert len(inputs) == 3
1154 ifm, filter, bias = inputs
1155 accum_dtype = args_dict["acc_type"]
1156
1157 result_tensor = OutputShaper.fullyConnectedOp(
James Ward8b390432022-08-12 20:48:56 +01001158 self.ser, self.rng, ifm, filter, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001159 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001160
1161 # Invalidate Input/Output list for error if checks.
1162 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001163 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001164 pCount, cCount = op["operands"]
1165 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001166 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1167 self, error_name, input_list, output_list
1168 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001169
Les Bell729b0352021-11-24 10:28:21 +00001170 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001171 self.ser,
1172 validator_fcns,
1173 error_name,
1174 op=op,
1175 input_shape=ifm.shape,
1176 input_dtype=ifm.dtype,
1177 weight_dtype=filter.dtype,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001178 output_shape=result_tensor.shape,
1179 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001180 qinfo=qinfo,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001181 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001182 input_list=input_list,
1183 output_list=output_list,
1184 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001185 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001186 ):
1187 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001188
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001189 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001190 attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001191
1192 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001193
1194 compliance = self.tensorComplianceMetaData(
1195 op, ifm.dtype, args_dict, result_tensor, error_name
1196 )
1197
1198 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001199
James Ward8b390432022-08-12 20:48:56 +01001200 def build_matmul(
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001201 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
James Ward8b390432022-08-12 20:48:56 +01001202 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001203 assert len(inputs) == 2
1204 a, b = inputs
Jeremy Johnson1271c442023-09-05 11:39:26 +01001205 accum_dtype = args_dict["acc_type"]
1206 result_tensor = OutputShaper.matmulOp(
James Ward8b390432022-08-12 20:48:56 +01001207 self.ser, self.rng, a, b, accum_dtype, error_name
1208 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001209
1210 # Invalidate Input/Output list for error if checks.
1211 input_list = [a.name, b.name]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001212 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001213 pCount, cCount = op["operands"]
1214 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001215 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1216 self, error_name, input_list, output_list
1217 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001218
Les Bell729b0352021-11-24 10:28:21 +00001219 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001220 self.ser,
1221 validator_fcns,
1222 error_name,
1223 op=op,
1224 input_shape=a.shape,
1225 input_dtype=a.dtype,
1226 input2_shape=b.shape,
1227 input2_dtype=b.dtype,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001228 output_shape=result_tensor.shape,
1229 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001230 qinfo=qinfo,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001231 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001232 input_list=input_list,
1233 output_list=output_list,
1234 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001235 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001236 ):
1237 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001238
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001239 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001240 attr.MatMulAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001241
1242 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001243
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001244 compliance = self.tensorComplianceMetaData(
1245 op, a.dtype, args_dict, result_tensor, error_name
1246 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001247
1248 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001249
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001250 def build_reduce(
1251 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
1252 ):
1253 assert len(inputs) == 1
1254 a = inputs[0]
1255 axis = args_dict["axis"]
1256 result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddond6ce7252021-09-29 15:35:44 +01001257
1258 # Invalidate Input/Output list for error if checks.
1259 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001260 output_list = [result_tensor.name]
Matthew Haddond6ce7252021-09-29 15:35:44 +01001261 pCount, cCount = op["operands"]
1262 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001263 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1264 self, error_name, input_list, output_list
1265 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01001266
Les Bell729b0352021-11-24 10:28:21 +00001267 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01001268 self.ser,
1269 validator_fcns,
1270 error_name,
1271 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001272 axis=axis,
1273 input_shape=a.shape,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001274 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001275 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001276 output_dtype=result_tensor.dtype,
1277 result_tensors=[result_tensor],
Matthew Haddond6ce7252021-09-29 15:35:44 +01001278 input_list=input_list,
1279 output_list=output_list,
1280 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001281 ):
1282 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001283
1284 attr = ts.TosaSerializerAttribute()
1285 attr.AxisAttribute(axis)
1286
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001287 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001288
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001289 if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
1290 # Number of products - needed for compliance
1291 args_dict["n"] = a.shape[axis]
1292
1293 compliance = self.tensorComplianceMetaData(
1294 op, a.dtype, args_dict, result_tensor, error_name
1295 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001296
1297 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001298
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001299 def build_clamp(
1300 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1301 ):
1302 assert len(inputs) == 1
1303 a = inputs[0]
1304
1305 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001306
Jeremy Johnson18e26662021-07-22 16:15:29 +01001307 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001308
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001309 if error_name == ErrorIf.MaxSmallerMin:
1310 # Make sure the numbers are different to invoke this error
1311 while v[0] == v[1]:
1312 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
1313 max_val = min(v)
1314 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001315 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001316 max_val = max(v)
1317 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001318
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001319 # Invalidate Input/Output list for error if checks.
1320 input_list = [a.name]
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001321 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001322 pCount, cCount = op["operands"]
1323 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001324 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1325 self, error_name, input_list, output_list
1326 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001327
Les Bell729b0352021-11-24 10:28:21 +00001328 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001329 self.ser,
1330 validator_fcns,
1331 error_name,
1332 op=op,
1333 max_val=max_val,
1334 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001335 input_shape=a.shape,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001336 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001337 input_dtype=a.dtype,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001338 output_dtype=result_tensor.dtype,
1339 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001340 input_list=input_list,
1341 output_list=output_list,
1342 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001343 ):
1344 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001345
1346 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001347 if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
1348 if a.dtype == DType.FP16:
1349 # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
1350 min_val = min_val.astype(np.float32)
1351 max_val = max_val.astype(np.float32)
1352
1353 attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001354 else:
James Ward34071252022-12-07 15:48:47 +00001355 attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001356
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001357 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001358
1359 compliance = self.tensorComplianceMetaData(
1360 op, a.dtype, args_dict, result_tensor, error_name
1361 )
1362
1363 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001364
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a leaky-ReLU operator on tensor `a`.

    A random FP32 value from self.getRandNumberDType is stored via
    LeakyReluAttribute (presumably the alpha slope - confirm against the
    serializer). validator_fcns is accepted but not used, so no ERROR_IF
    validation runs in this builder.

    Returns:
        The result tensor itself (not a BuildInfo, unlike newer builders).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    alpha_value = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(alpha_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1373
# TODO: this op needs an additional type/input; only `a` is wired up below.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator using only the single input tensor `a`.

    No attribute is attached and validator_fcns is not used, so no ERROR_IF
    validation runs here.

    Returns:
        The result tensor itself (not a BuildInfo).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1380
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an attribute-less activation operator on a single input tensor.

    Args:
        op: operator description; op["op"] is the code handed to the
            serializer and op["operands"] is a pair of counts that are summed
            into num_operands for validation.
        inputs: sequence holding exactly one input tensor.
        args_dict: forwarded to self.tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001421
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator over every tensor in `inputs`.

    CONCAT_SHAPE always uses axis 0 and is serialized without an attribute.
    CONCAT reads args_dict["axis"] and serializes it as an AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    is_shape_op = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_op else args_dict["axis"]
    # The int check is skipped when generating a WrongInputType test.
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [tensor.name for tensor in inputs]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=out.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=out.dtype,
        inputs=inputs,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        assert is_shape_op
        concat_attr = None
    self.ser.addOperator(op["op"], in_names, out_names, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001480
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator on a single input tensor.

    args_dict supplies "pad" (an array-like whose .flatten() goes into the
    attribute), "pad_const_int" and "pad_const_fp". qinfo is forwarded to
    ERROR_IF validation.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, padding.flatten(), const_int, const_fp
    )

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001538
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying args_dict["axis"] of a single input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata
        is attached. Returns None when TosaErrorValidator.evValidateErrorIfs
        rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001584
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator from a data tensor and a shape operand tensor.

    The output shape is computed from args_dict["new_shape"]. Both input
    tensors are wired to the operator, and no attribute is serialized.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    data, shape_tensor = inputs
    new_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(self.ser, self.rng, data, new_shape, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [data.name, shape_tensor.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=data.shape,
        output_shape=out.shape,
        input_dtype=data.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001628
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along args_dict["axis"] of a single input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata
        is attached. Returns None when TosaErrorValidator.evValidateErrorIfs
        rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001668
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on `a` with permutation `perms`.

    Returns:
        The result tensor itself (not a BuildInfo), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [a.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out
1704
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator from args_dict["start"] and args_dict["size"].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    start = args_dict["start"]
    size = args_dict["size"]

    out = OutputShaper.sliceOp(self.ser, self.rng, src, start, size, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        start=start,
        size=size,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=src,
    )
    if not is_valid:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001755
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator from a data tensor and a multiples operand tensor.

    The output shape is computed from args_dict["multiples"]. Both input
    tensors are wired to the operator, and no attribute is serialized.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    data, multiples_tensor = inputs
    multiples = args_dict["multiples"]
    out = OutputShaper.tileOp(self.ser, self.rng, data, multiples, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [data.name, multiples_tensor.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=data.shape,
        output_shape=out.shape,
        input_dtype=data.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=data,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001800
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), where
        compliance uses the values dtype. Returns None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    out = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, indices_tensor, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [values_tensor.name, indices_tensor.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001843
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from (values_in, indices, input) tensors.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), where
        compliance uses the values_in dtype. Returns None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    # Named input_tensor rather than `input` to avoid shadowing the builtin.
    values_in, indices_tensor, input_tensor = inputs
    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices_tensor, input_tensor, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [values_in.name, indices_tensor.name, input_tensor.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001885
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on tensor `input`.

    mode, scale, offset and border are passed to both OutputShaper.resizeOp
    and ResizeAttribute. input_dtype/output_dtype are given explicitly rather
    than read from the tensors.

    Returns:
        The result tensor itself (not a BuildInfo), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    # Keep the public parameter name `input`; use a local alias below.
    src = input

    out = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names = [src.name]
    out_names = [out.name]
    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=out.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out],
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out
1947
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    One result tensor is created per input via OutputShaper.unaryOp, and a
    single operator is registered with both inputs and both outputs.
    validator_fcns is not used, so no ERROR_IF validation runs here.

    Args:
        op: operator description. As with the sibling builders, the
            operator code is read from op["op"]. A bare operator code (not a
            dict) is still accepted for backward compatibility.
        val, val2: the two input tensors.
        error_name: ERROR_IF case forwarded to OutputShaper.unaryOp.

    Returns:
        The first result tensor. The second output is registered with the
        serializer but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Previously the whole `op` description was passed to addOperator, while
    # every other builder passes op["op"]. Extract the operator code, but
    # tolerate callers that already pass the bare code.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1955
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the pre-built tensor `val` as a graph output and return it.

    op, validator_fcns and error_name are accepted for signature
    compatibility with the other builders but are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001959
1960 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator converting ``inputs[0]`` to ``args_dict["out_type"]``.

    Returns a ``TosaTestGen.BuildInfo`` with the result tensor and its
    compliance metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tens = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, in_tens, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be deliberately corrupted
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tens.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tens.shape,
        output_shape=out_tens.shape,
        input_dtype=in_tens.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, in_tens.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002004
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Random zero points are chosen where the input/output type (or the
    ERROR_IF case under test) calls for them. A random scale, derived from
    the input/output type widths, is turned into multiplier/shift pairs (one
    per channel when ``per_channel`` is set, otherwise one). For scale32
    positive tests, the stored input values are clamped so they satisfy the
    apply_scale_32 REQUIRES condition.

    Returns the result tensor, or None when ERROR_IF validation rejects the
    generated configuration.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension), or a single pair
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Input zero point. Each branch that picks a zero point also widens the
    # type width by one bit (presumably headroom for the zp subtraction).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    # Output zero point, chosen the same way as the input zero point
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Use a Python int so the shifts below are computed with arbitrary
        # precision. With an np.int32 operand, NumPy 2 promotion (NEP 50)
        # keeps the result in int32, which overflows once shift - 1 >= 31.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check every element can safely be converted back to the stored dtype
        # (the previous check compared val_adj.all(), a single bool, which
        # never tested the values themselves)
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Cast back to the dtype the values were stored with
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2177
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Positive tests get a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool case swaps in a non-BOOL type, and the
    CondIfCondShapeNotSizeOne case uses a shape with more than one element.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2194
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    ``then_tens`` only supplies the output shape; ``else_tens`` is ignored.
    ``cond`` is the value stored in the condition constant. For the
    output-list mismatch ERROR_IF cases, the offending block's constant is
    given a shape that differs from the COND_IF output.

    Returns the COND_IF result tensor, or None when ERROR_IF validation
    rejects the generated graph.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)
    out_shape = then_tens.shape

    shape_mismatch = error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if shape_mismatch:
        # Perturb every dimension: dims > 3 move by 2-3, smaller dims only grow
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result has the shape of the expected block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block pairs with the ERROR_IF case that corrupts its output shape
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, block_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            const_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            const_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(const_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2263
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    Float and INT32 inputs use ADD (then) / SUB (else). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT (else). For the
    input/output-list mismatch ERROR_IF cases, the offending block gets an
    input or output whose shape differs from ``a``.

    Returns the COND_IF result tensor, or None when ERROR_IF validation
    rejects the generated graph.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop name: reusing ``op`` here would shadow the operator
    # dict and hand the ERROR_IF validator the ELSE block's Op instead.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2346
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` while a counter is positive.

    The loop carries (counter, a, accumulator). The counter starts at
    ``iter_val`` and the accumulator at zeros. The COND block tests
    counter > 0; the BODY block sets acc = a + acc and counter = counter - 1.
    ERROR_IF variants swap in tensors with mismatched shapes, or a condition
    output of the wrong type/shape.

    Returns the accumulator output tensor, or None when ERROR_IF validation
    rejects the generated graph.
    """
    # Loop counter (named iter_tens rather than ``iter`` to avoid shadowing
    # the builtin)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # A rank-0 counter gets an extra dimension so the shape differs
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2467
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator over the two input tensors ``val1``/``val2``.

    ``inverse`` selects the inverse transform and is stored in the FFT
    attribute. Returns the list of result tensors produced by
    ``OutputShaper.fft2dOp``, or None when ERROR_IF validation rejects the
    test.
    """
    fft_outputs = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [tens.shape for tens in fft_outputs]
    dtypes_out = [tens.dtype for tens in fft_outputs]

    # For ERROR_IF tests the input/output name lists may be deliberately corrupted
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tens.name for tens in fft_outputs],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=fft_outputs,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the local bound attribute is always False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return fft_outputs
2518
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator over the single input tensor ``val``.

    Returns the list of result tensors produced by ``OutputShaper.rfft2dOp``,
    or None when ERROR_IF validation rejects the test.
    """
    rfft_outputs = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [tens.shape for tens in rfft_outputs]
    dtypes_out = [tens.dtype for tens in rfft_outputs]

    # For ERROR_IF tests the input/output name lists may be deliberately corrupted
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tens.name for tens in rfft_outputs]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=rfft_outputs,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the local bound attribute is always False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, rfft_attr)
    return rfft_outputs
2564
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator from ``inputs``.

    The output tensor is shaped by ``OutputShaper.addShapeOp``. Returns a
    ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
    metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tens = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be deliberately corrupted
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
2610
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters to use when generating tests for ``op``.

    The requested filters are intersected with what the operator allows
    (``op["rank"]`` as an inclusive (min, max) pair and ``op["types"]``).

    - "positive": returns the cleaned filters.
    - "negative": ``validator(check=False, op=op)["param_reqs"]`` may override
      each filter. Without an override, the shape list is cut to its first
      two entries. Returns None when no validator is given.
    - Any other ``testType``: returns None.
    """
    # Default testing ranks are 1-4 inclusive to keep test sizes reasonably small
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    op_ranks = range(rank_lo, rank_hi + 1)
    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        allowed_ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is None:
        # Bound the default by the default range or by the operator,
        # whichever is the smaller range of ranks
        allowed_ranks = (
            op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
        )
    else:
        allowed_ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        allowed_dtypes = op_dtypes
    else:
        # A list entry matches on its first element
        allowed_dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": allowed_ranks,
            "dtypeFilter": allowed_dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    param_reqs = validator(check=False, op=op)["param_reqs"]

    # A non-None requirement from the validator overrides the cleaned filter
    rank_req = param_reqs["rank"]
    neg_ranks = allowed_ranks if rank_req is None else rank_req

    dtype_req = param_reqs["dtype"]
    neg_dtypes = allowed_dtypes if dtype_req is None else dtype_req

    shape_req = param_reqs["shape"]
    # Reduce number of shapes to keep test numbers small
    neg_shapes = shapeFilter[:2] if shape_req is None else shape_req

    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2691
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for operator ``opName``.

    Side effects: ``self.rng`` is re-seeded from ``self.random_seed``, and
    ``self.setTargetShape`` is called for every generated shape.

    Args:
        opName: Key into TOSA_OP_LIST.
        shapeFilter: Optional list of shapes to restrict generation to. None
            (the default) means no shape restriction, i.e. ``[None]``.
        rankFilter: Optional iterable of ranks to generate.
        dtypeFilter: Optional collection of dtypes to generate.
        testType: ``"positive"``, or ``"negative"`` to generate one set of
            ERROR_IF tests per entry in the op's ``"error_if_validators"``.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. The list is empty when ``create_filter_lists`` returns None.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    # None replaces the previous shared mutable default of [None]
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Negative tests make one pass per ERROR_IF validator; everything else
    # makes a single pass without a validator.
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Each test is a tuple of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        if testType == "negative":
            namePrefix = f"{opName}_ERRORIF_{error_name}"
        else:
            namePrefix = opName

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{namePrefix}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr = f"{testStr}_{argStr}"
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2798
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test for ``opName`` and serialize it as ``"test"``.

    Tensor values come from the TensorValuesGen function in
    ``op["build_fcn"]``, and the operator is built with that tuple's build
    function. If the build result is falsy, the ERROR_IF test is reported as
    invalid and nothing is serialized.

    Args:
        opName: Key into TOSA_OP_LIST.
        testStr: Test name, passed to the serializer and used in messages.
        dtype_or_dtypeList: A list with one dtype per operand, or a single
            dtype that is repeated for every operand (or, for CONCAT and
            CONCAT_SHAPE, for every shape).
        error_name: ERROR_IF error being tested, or None.
        shapeList: One shape per operand.
        testArgs: Either an args dictionary (new interface, must contain
            ``"dg_type"``) or a sequence of positional build arguments (old
            interface).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # CONCAT/CONCAT_SHAPE get one dtype per shape and skip the operand-count
    # checks, because their input count follows shapeList rather than
    # op["operands"].
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_concat:
        assert len(shapeList) == num_operands, (
            "shapeList length {} must match number of operands {}".format(
                len(shapeList), num_operands
            )
        )
        assert len(dtypeList) == num_operands, (
            "dtypeList length {} must match number of operands {}".format(
                len(dtypeList), num_operands
            )
        )

    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: the arguments arrive as a dictionary
        assert "dg_type" in testArgs
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, testArgs, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        result = build_fcn(
            self,
            op,
            tvgInfo.tensorList,
            testArgs,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: the arguments arrive as a positional list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        if error_if_validators is None:
            # Without validators, qinfo (if any) follows the test arguments
            # positionally.
            trailing = () if qinfo is None else (qinfo,)
            keywords = {}
        else:
            trailing = ()
            keywords = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                keywords["qinfo"] = qinfo
        try:
            result = build_fcn(self, op, *tens, *testArgs, *trailing, **keywords)
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The build rejected the test
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid. Attach compliance meta data when the build supplies it.
    # NOTE: This currently expects only one result output
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002924
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` convolution entries of TOSA_OP_LIST in place.

    For each kernel size, every convolution template is copied into a concrete
    entry named after the kernel (e.g. ``conv2d_3x1``). The copy has
    ``"filter"`` set to the kernel and ``"template"`` set to False. Afterwards,
    every entry still flagged as a template is removed.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to generate; level 8K testing uses the maximum kernel size
    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, bigK], [bigK, 2]]
        kernels_3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # (template prefixes, kernels). Entries are added kernel by kernel, and
    # for each kernel template by template.
    expansions = (
        (("conv2d", "depthwise_conv2d", "transpose_conv2d"), kernels_2d),
        (("conv3d",), kernels_3d),
    )
    for prefixes, kernels in expansions:
        for kernel in kernels:
            suffix = "x".join(str(dim) for dim in kernel)
            for prefix in prefixes:
                template = op_list[f"{prefix}_TEMPLATE"]
                op_list[f"{prefix}_{suffix}"] = {
                    **template,
                    "filter": kernel,
                    "template": False,
                }

    # Collect the templates first, then delete them: the dict cannot shrink
    # while it is being iterated.
    stale = [name for name, spec in op_list.items() if spec.get("template")]
    for name in stale:
        del op_list[name]
2980
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in missing defaults.

    Each entry must have:
    - a 2-element ``"operands"`` tuple,
    - a 4-element ``"build_fcn"`` tuple,
    - a ``"types"`` entry,
    - an ``"op"`` entry.

    Otherwise an Exception naming the op is raised. Entries without a
    ``"rank"`` receive ``self.DEFAULT_RANK_RANGE``.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Unpacking also rejects tuples of the wrong length (ValueError) and
        # values that cannot be unpacked (TypeError).
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            )

        try:
            _build, _tensorGen, _valuesGen, _argGen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        if "types" not in opInfo:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in opInfo:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Put in default rank range, if missing
        if "rank" not in opInfo:
            opInfo["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003022
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional: 'qgen', 'error_if_validators', 'invalid_test_validators',
#    'data_gen', and 'template' (entries expanded by createDynamicOpLists)
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (excludes INT4)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer types (excludes INT4) followed by the floating-point types
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer (excludes INT4) and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used by ops that do not define their own "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003074
3075 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003076 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003077 "argmax": {
3078 "op": Op.ARGMAX,
3079 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003080 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003081 "build_fcn": (
3082 build_argmax,
3083 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003084 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003085 TosaArgGen.agAxis,
3086 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003087 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003088 "error_if_validators": (
3089 TosaErrorValidator.evAxisSmallerZero,
3090 TosaErrorValidator.evAxisLargerRank,
3091 TosaErrorValidator.evArgmaxOutputRankMismatch,
3092 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3093 TosaErrorValidator.evWrongRank,
3094 TosaErrorValidator.evWrongInputType,
3095 TosaErrorValidator.evWrongOutputType,
3096 TosaErrorValidator.evWrongInputList,
3097 TosaErrorValidator.evWrongOutputList,
3098 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003099 "data_gen": {
3100 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3101 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003102 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003103 "avg_pool2d": {
3104 "op": Op.AVG_POOL2D,
3105 "operands": (1, 0),
3106 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003107 "build_fcn": (
3108 build_pool2d,
3109 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003110 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111 TosaArgGen.agPooling,
3112 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003113 "qgen": TosaQuantGen.qgUnary,
3114 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003115 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003116 "error_if_validators": (
3117 TosaErrorValidator.evKernelSmallerOne,
3118 TosaErrorValidator.evStrideSmallerOne,
3119 TosaErrorValidator.evPadSmallerZero,
3120 TosaErrorValidator.evWrongRank,
3121 TosaErrorValidator.evWrongInputType,
3122 TosaErrorValidator.evWrongOutputType,
3123 TosaErrorValidator.evWrongInputList,
3124 TosaErrorValidator.evWrongOutputList,
3125 TosaErrorValidator.evInputZeroPointNotZero,
3126 TosaErrorValidator.evOutputZeroPointNotZero,
3127 TosaErrorValidator.evPadLargerEqualKernel,
3128 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003129 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003130 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003131 "data_gen": {
3132 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3133 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003134 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003135 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003136 "conv2d_TEMPLATE": {
3137 "op": Op.CONV2D,
3138 "operands": (1, 2),
3139 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003140 "build_fcn": (
3141 build_conv2d,
3142 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003143 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003144 TosaArgGen.agConv,
3145 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003146 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003147 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003148 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3149 "error_if_validators": (
3150 TosaErrorValidator.evWrongInputType,
3151 TosaErrorValidator.evWrongOutputType,
3152 TosaErrorValidator.evWrongInputList,
3153 TosaErrorValidator.evWrongOutputList,
3154 TosaErrorValidator.evInputZeroPointNotZero,
3155 TosaErrorValidator.evWeightZeroPointNotZero,
3156 TosaErrorValidator.evPadSmallerZero,
3157 TosaErrorValidator.evStrideSmallerOne,
3158 TosaErrorValidator.evDilationSmallerOne,
3159 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003160 TosaErrorValidator.evConvOutputShapeMismatch,
3161 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003162 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003163 "data_gen": {
3164 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3165 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003166 "template": True,
3167 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003168 # Templated operator. Filled in by createDynamicOpLists
3169 "conv3d_TEMPLATE": {
3170 "op": Op.CONV3D,
3171 "operands": (1, 2),
3172 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 "build_fcn": (
3174 build_conv3d,
3175 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003176 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 TosaArgGen.agConv,
3178 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003179 "qgen": TosaQuantGen.qgConv,
3180 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003181 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3182 "error_if_validators": (
3183 TosaErrorValidator.evWrongInputType,
3184 TosaErrorValidator.evWrongOutputType,
3185 TosaErrorValidator.evWrongInputList,
3186 TosaErrorValidator.evWrongOutputList,
3187 TosaErrorValidator.evInputZeroPointNotZero,
3188 TosaErrorValidator.evWeightZeroPointNotZero,
3189 TosaErrorValidator.evPadSmallerZero,
3190 TosaErrorValidator.evStrideSmallerOne,
3191 TosaErrorValidator.evDilationSmallerOne,
3192 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003193 TosaErrorValidator.evConvOutputShapeMismatch,
3194 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003195 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003196 "template": True,
3197 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003198 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003199 "depthwise_conv2d_TEMPLATE": {
3200 "op": Op.DEPTHWISE_CONV2D,
3201 "operands": (1, 2),
3202 "filter": [1, 1],
3203 "rank": (4, 4),
3204 "build_fcn": (
3205 build_depthwise_conv2d,
3206 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003207 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003208 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003209 ),
3210 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003211 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003212 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3213 "error_if_validators": (
3214 TosaErrorValidator.evWrongInputType,
3215 TosaErrorValidator.evWrongOutputType,
3216 TosaErrorValidator.evWrongInputList,
3217 TosaErrorValidator.evWrongOutputList,
3218 TosaErrorValidator.evInputZeroPointNotZero,
3219 TosaErrorValidator.evWeightZeroPointNotZero,
3220 TosaErrorValidator.evPadSmallerZero,
3221 TosaErrorValidator.evStrideSmallerOne,
3222 TosaErrorValidator.evDilationSmallerOne,
3223 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003224 TosaErrorValidator.evConvOutputShapeMismatch,
3225 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003226 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003227 "data_gen": {
3228 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3229 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003230 "template": True,
3231 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003232 "fully_connected": {
3233 "op": Op.FULLY_CONNECTED,
3234 "operands": (1, 2),
3235 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003236 "build_fcn": (
3237 build_fully_connected,
3238 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003239 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003240 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003243 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003244 "error_if_validators": (
3245 TosaErrorValidator.evInputZeroPointNotZero,
3246 TosaErrorValidator.evWeightZeroPointNotZero,
3247 TosaErrorValidator.evWrongRank,
3248 TosaErrorValidator.evWrongInputType,
3249 TosaErrorValidator.evWrongOutputType,
3250 TosaErrorValidator.evWrongInputList,
3251 TosaErrorValidator.evWrongOutputList,
3252 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003253 "data_gen": {
3254 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3255 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 "matmul": {
3258 "op": Op.MATMUL,
3259 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003260 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 "build_fcn": (
3262 build_matmul,
3263 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003264 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003265 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003267 "qgen": TosaQuantGen.qgMatmul,
3268 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003269 "error_if_validators": (
3270 TosaErrorValidator.evInputZeroPointNotZero,
3271 TosaErrorValidator.evWrongRank,
3272 TosaErrorValidator.evWrongInputType,
3273 TosaErrorValidator.evWrongOutputType,
3274 TosaErrorValidator.evWrongInputList,
3275 TosaErrorValidator.evWrongOutputList,
3276 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003277 "data_gen": {
3278 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003279 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 "max_pool2d": {
3282 "op": Op.MAX_POOL2D,
3283 "operands": (1, 0),
3284 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003286 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003287 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003288 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 TosaArgGen.agPooling,
3290 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003291 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003292 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 "error_if_validators": (
3294 TosaErrorValidator.evKernelSmallerOne,
3295 TosaErrorValidator.evStrideSmallerOne,
3296 TosaErrorValidator.evPadSmallerZero,
3297 TosaErrorValidator.evWrongRank,
3298 TosaErrorValidator.evWrongInputType,
3299 TosaErrorValidator.evWrongOutputType,
3300 TosaErrorValidator.evWrongInputList,
3301 TosaErrorValidator.evWrongOutputList,
3302 TosaErrorValidator.evPadLargerEqualKernel,
3303 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003304 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003306 "data_gen": {
3307 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003310 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003311 "transpose_conv2d_TEMPLATE": {
3312 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003313 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003314 "rank": (4, 4),
3315 "build_fcn": (
3316 build_transpose_conv2d,
3317 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003318 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003319 TosaArgGen.agTransposeConv2D,
3320 ),
3321 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003322 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003323 "invalid_test_validators": (
3324 TosaInvalidValidator.ivHeightWidthInvalid,
3325 TosaInvalidValidator.ivNonPositiveOutputShape,
3326 ),
3327 "error_if_validators": (
3328 TosaErrorValidator.evWrongInputType,
3329 TosaErrorValidator.evWrongOutputType,
3330 TosaErrorValidator.evWrongInputList,
3331 TosaErrorValidator.evWrongOutputList,
3332 TosaErrorValidator.evInputZeroPointNotZero,
3333 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003334 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003335 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003336 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003337 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003338 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003339 "data_gen": {
3340 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3341 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003342 "template": True,
3343 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003344 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003345 "clamp": {
3346 "op": Op.CLAMP,
3347 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 "build_fcn": (
3349 build_clamp,
3350 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003351 TosaTensorValuesGen.tvgLazyGenDefault,
3352 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003353 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003354 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003355 "error_if_validators": (
3356 TosaErrorValidator.evMaxSmallerMin,
3357 TosaErrorValidator.evWrongInputType,
3358 TosaErrorValidator.evWrongOutputType,
3359 TosaErrorValidator.evWrongInputList,
3360 TosaErrorValidator.evWrongOutputList,
3361 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003362 "data_gen": {
3363 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3364 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003365 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003366 "sigmoid": {
3367 "op": Op.SIGMOID,
3368 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003369 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003370 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003371 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003372 TosaTensorValuesGen.tvgLazyGenDefault,
3373 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003374 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003375 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003376 "error_if_validators": (
3377 TosaErrorValidator.evWrongInputType,
3378 TosaErrorValidator.evWrongOutputType,
3379 TosaErrorValidator.evWrongInputList,
3380 TosaErrorValidator.evWrongOutputList,
3381 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003382 "data_gen": {
3383 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3384 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003385 },
3386 "tanh": {
3387 "op": Op.TANH,
3388 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003389 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003390 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003391 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003392 TosaTensorValuesGen.tvgLazyGenDefault,
3393 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003394 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003395 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003396 "error_if_validators": (
3397 TosaErrorValidator.evWrongInputType,
3398 TosaErrorValidator.evWrongOutputType,
3399 TosaErrorValidator.evWrongInputList,
3400 TosaErrorValidator.evWrongOutputList,
3401 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003402 "data_gen": {
3403 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3404 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003405 "compliance": {
3406 "abs_error_lower_bound": 0.5,
3407 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003408 },
Won Jeon78155c62023-06-10 00:20:04 +00003409 "erf": {
3410 "op": Op.ERF,
3411 "operands": (1, 0),
3412 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003413 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003414 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003415 TosaTensorValuesGen.tvgLazyGenDefault,
3416 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003417 ),
3418 "types": TYPE_FP,
3419 "error_if_validators": (
3420 TosaErrorValidator.evWrongInputType,
3421 TosaErrorValidator.evWrongOutputType,
3422 TosaErrorValidator.evWrongInputList,
3423 TosaErrorValidator.evWrongOutputList,
3424 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003425 "data_gen": {
3426 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3427 },
3428 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003429 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 # Elementwise Binary Operators
3431 "add": {
3432 "op": Op.ADD,
3433 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 "build_fcn": (
3435 build_binary_broadcast,
3436 TosaTensorGen.tgBroadcastFuzz,
3437 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003438 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003441 "error_if_validators": (
3442 TosaErrorValidator.evRankMismatch,
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003448 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003450 "data_gen": {
3451 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3452 },
3453 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "arithmetic_right_shift": {
3456 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3457 "operands": (2, 0),
3458 "build_fcn": (
3459 build_arithmetic_right_shift,
3460 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 TosaArgGen.agArithmeticRightShift,
3463 ),
3464 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003465 "error_if_validators": (
3466 TosaErrorValidator.evRankMismatch,
3467 TosaErrorValidator.evWrongInputType,
3468 TosaErrorValidator.evWrongOutputType,
3469 TosaErrorValidator.evWrongInputList,
3470 TosaErrorValidator.evWrongOutputList,
3471 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003472 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003473 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 "bitwise_and": {
3476 "op": Op.BITWISE_AND,
3477 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003478 "build_fcn": (
3479 build_binary_broadcast,
3480 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003481 TosaTensorValuesGen.tvgLazyGenDefault,
3482 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003483 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003485 "error_if_validators": (
3486 TosaErrorValidator.evRankMismatch,
3487 TosaErrorValidator.evWrongInputType,
3488 TosaErrorValidator.evWrongOutputType,
3489 TosaErrorValidator.evWrongInputList,
3490 TosaErrorValidator.evWrongOutputList,
3491 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003492 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003493 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003494 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003495 "bitwise_or": {
3496 "op": Op.BITWISE_OR,
3497 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003498 "build_fcn": (
3499 build_binary_broadcast,
3500 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003501 TosaTensorValuesGen.tvgLazyGenDefault,
3502 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003505 "error_if_validators": (
3506 TosaErrorValidator.evRankMismatch,
3507 TosaErrorValidator.evWrongInputType,
3508 TosaErrorValidator.evWrongOutputType,
3509 TosaErrorValidator.evWrongInputList,
3510 TosaErrorValidator.evWrongOutputList,
3511 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003512 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "bitwise_xor": {
3516 "op": Op.BITWISE_XOR,
3517 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003518 "build_fcn": (
3519 build_binary_broadcast,
3520 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003521 TosaTensorValuesGen.tvgLazyGenDefault,
3522 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003525 "error_if_validators": (
3526 TosaErrorValidator.evRankMismatch,
3527 TosaErrorValidator.evWrongInputType,
3528 TosaErrorValidator.evWrongOutputType,
3529 TosaErrorValidator.evWrongInputList,
3530 TosaErrorValidator.evWrongOutputList,
3531 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003532 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003533 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003534 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003535 "intdiv": {
3536 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003537 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003538 "build_fcn": (
3539 build_binary_broadcast,
3540 TosaTensorGen.tgBroadcastFuzz,
3541 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003542 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003544 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003545 "error_if_validators": (
3546 TosaErrorValidator.evRankMismatch,
3547 TosaErrorValidator.evWrongInputType,
3548 TosaErrorValidator.evWrongOutputType,
3549 TosaErrorValidator.evWrongInputList,
3550 TosaErrorValidator.evWrongOutputList,
3551 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003552 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003554 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 "logical_and": {
3556 "op": Op.LOGICAL_AND,
3557 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 "build_fcn": (
3559 build_binary_broadcast,
3560 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003561 TosaTensorValuesGen.tvgLazyGenDefault,
3562 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003565 "error_if_validators": (
3566 TosaErrorValidator.evRankMismatch,
3567 TosaErrorValidator.evWrongInputType,
3568 TosaErrorValidator.evWrongOutputType,
3569 TosaErrorValidator.evWrongInputList,
3570 TosaErrorValidator.evWrongOutputList,
3571 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003572 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003573 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 "logical_left_shift": {
3576 "op": Op.LOGICAL_LEFT_SHIFT,
3577 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003578 "build_fcn": (
3579 build_binary_broadcast,
3580 TosaTensorGen.tgBroadcastFuzz,
3581 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003582 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003583 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003584 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003585 "error_if_validators": (
3586 TosaErrorValidator.evRankMismatch,
3587 TosaErrorValidator.evWrongInputType,
3588 TosaErrorValidator.evWrongOutputType,
3589 TosaErrorValidator.evWrongInputList,
3590 TosaErrorValidator.evWrongOutputList,
3591 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003592 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 "logical_right_shift": {
3596 "op": Op.LOGICAL_RIGHT_SHIFT,
3597 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 "build_fcn": (
3599 build_binary_broadcast,
3600 TosaTensorGen.tgBroadcastFuzz,
3601 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003602 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evRankMismatch,
3607 TosaErrorValidator.evWrongInputType,
3608 TosaErrorValidator.evWrongOutputType,
3609 TosaErrorValidator.evWrongInputList,
3610 TosaErrorValidator.evWrongOutputList,
3611 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003612 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "logical_or": {
3616 "op": Op.LOGICAL_OR,
3617 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 "build_fcn": (
3619 build_binary_broadcast,
3620 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003621 TosaTensorValuesGen.tvgLazyGenDefault,
3622 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003623 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003624 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003625 "error_if_validators": (
3626 TosaErrorValidator.evRankMismatch,
3627 TosaErrorValidator.evWrongInputType,
3628 TosaErrorValidator.evWrongOutputType,
3629 TosaErrorValidator.evWrongInputList,
3630 TosaErrorValidator.evWrongOutputList,
3631 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003632 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 "logical_xor": {
3636 "op": Op.LOGICAL_XOR,
3637 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003638 "build_fcn": (
3639 build_binary_broadcast,
3640 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003641 TosaTensorValuesGen.tvgLazyGenDefault,
3642 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003644 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003645 "error_if_validators": (
3646 TosaErrorValidator.evRankMismatch,
3647 TosaErrorValidator.evWrongInputType,
3648 TosaErrorValidator.evWrongOutputType,
3649 TosaErrorValidator.evWrongInputList,
3650 TosaErrorValidator.evWrongOutputList,
3651 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003652 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003653 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003654 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 "maximum": {
3656 "op": Op.MAXIMUM,
3657 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 "build_fcn": (
3659 build_binary_broadcast,
3660 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003661 TosaTensorValuesGen.tvgLazyGenDefault,
3662 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003665 "error_if_validators": (
3666 TosaErrorValidator.evRankMismatch,
3667 TosaErrorValidator.evWrongInputType,
3668 TosaErrorValidator.evWrongOutputType,
3669 TosaErrorValidator.evWrongInputList,
3670 TosaErrorValidator.evWrongOutputList,
3671 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003672 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003674 "data_gen": {
3675 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "minimum": {
3679 "op": Op.MINIMUM,
3680 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003681 "build_fcn": (
3682 build_binary_broadcast,
3683 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003684 TosaTensorValuesGen.tvgLazyGenDefault,
3685 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003688 "error_if_validators": (
3689 TosaErrorValidator.evRankMismatch,
3690 TosaErrorValidator.evWrongInputType,
3691 TosaErrorValidator.evWrongOutputType,
3692 TosaErrorValidator.evWrongInputList,
3693 TosaErrorValidator.evWrongOutputList,
3694 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003695 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003697 "data_gen": {
3698 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3699 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "mul": {
3702 "op": Op.MUL,
3703 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_mul,
3706 TosaTensorGen.tgBroadcastFuzz,
3707 TosaTensorValuesGen.tvgMul,
3708 TosaArgGen.agMul,
3709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 TosaErrorValidator.evRankMismatch,
3717 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003718 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003720 "data_gen": {
3721 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3722 },
3723 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 "pow": {
3726 "op": Op.POW,
3727 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003728 "build_fcn": (
3729 build_binary_broadcast,
3730 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003731 TosaTensorValuesGen.tvgPow,
3732 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003734 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003735 "error_if_validators": (
3736 TosaErrorValidator.evRankMismatch,
3737 TosaErrorValidator.evWrongInputType,
3738 TosaErrorValidator.evWrongOutputType,
3739 TosaErrorValidator.evWrongInputList,
3740 TosaErrorValidator.evWrongOutputList,
3741 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003742 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003743 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003744 "data_gen": {
3745 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3746 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "sub": {
3749 "op": Op.SUB,
3750 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003751 "build_fcn": (
3752 build_binary_broadcast,
3753 TosaTensorGen.tgBroadcastFuzz,
3754 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003755 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003757 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003758 "error_if_validators": (
3759 TosaErrorValidator.evRankMismatch,
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
3764 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003765 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003766 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003767 "data_gen": {
3768 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3769 },
3770 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 "table": {
3773 "op": Op.TABLE,
3774 # Use the automatic generation functions to create the input array
3775 # but create the table tensor in the build function, as it may be
3776 # a different type from the input
3777 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 "build_fcn": (
3779 build_table,
3780 TosaTensorGen.tgBasic,
3781 TosaTensorValuesGen.tvgDefault,
3782 TosaArgGen.agTable,
3783 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003784 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 "error_if_validators": (
3786 TosaErrorValidator.evWrongInputType,
3787 TosaErrorValidator.evWrongOutputType,
3788 TosaErrorValidator.evWrongInputList,
3789 TosaErrorValidator.evWrongOutputList,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 # Elementwise Unary operators
3793 "abs": {
3794 "op": Op.ABS,
3795 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003796 "build_fcn": (
3797 build_unary,
3798 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003799 TosaTensorValuesGen.tvgLazyGenDefault,
3800 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003803 "error_if_validators": (
3804 TosaErrorValidator.evWrongInputType,
3805 TosaErrorValidator.evWrongOutputType,
3806 TosaErrorValidator.evWrongInputList,
3807 TosaErrorValidator.evWrongOutputList,
3808 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003809 "data_gen": {
3810 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3811 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003812 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "bitwise_not": {
3814 "op": Op.BITWISE_NOT,
3815 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 "build_fcn": (
3817 build_unary,
3818 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003819 TosaTensorValuesGen.tvgLazyGenDefault,
3820 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003821 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003822 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 "error_if_validators": (
3824 TosaErrorValidator.evWrongInputType,
3825 TosaErrorValidator.evWrongOutputType,
3826 TosaErrorValidator.evWrongInputList,
3827 TosaErrorValidator.evWrongOutputList,
3828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "ceil": {
3831 "op": Op.CEIL,
3832 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003833 "build_fcn": (
3834 build_unary,
3835 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003836 TosaTensorValuesGen.tvgLazyGenDefault,
3837 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003840 "error_if_validators": (
3841 TosaErrorValidator.evWrongInputType,
3842 TosaErrorValidator.evWrongOutputType,
3843 TosaErrorValidator.evWrongInputList,
3844 TosaErrorValidator.evWrongOutputList,
3845 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003846 "data_gen": {
3847 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3848 },
3849 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 "clz": {
3852 "op": Op.CLZ,
3853 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003854 "build_fcn": (
3855 build_unary,
3856 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003857 TosaTensorValuesGen.tvgLazyGenDefault,
3858 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003861 "error_if_validators": (
3862 TosaErrorValidator.evWrongInputType,
3863 TosaErrorValidator.evWrongOutputType,
3864 TosaErrorValidator.evWrongInputList,
3865 TosaErrorValidator.evWrongOutputList,
3866 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 "exp": {
3869 "op": Op.EXP,
3870 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003871 "build_fcn": (
3872 build_unary,
3873 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003874 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003875 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003878 "error_if_validators": (
3879 TosaErrorValidator.evWrongInputType,
3880 TosaErrorValidator.evWrongOutputType,
3881 TosaErrorValidator.evWrongInputList,
3882 TosaErrorValidator.evWrongOutputList,
3883 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003884 "data_gen": {
3885 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3886 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003887 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 "floor": {
3889 "op": Op.FLOOR,
3890 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003891 "build_fcn": (
3892 build_unary,
3893 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003894 TosaTensorValuesGen.tvgLazyGenDefault,
3895 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003897 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003898 "error_if_validators": (
3899 TosaErrorValidator.evWrongInputType,
3900 TosaErrorValidator.evWrongOutputType,
3901 TosaErrorValidator.evWrongInputList,
3902 TosaErrorValidator.evWrongOutputList,
3903 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003904 "data_gen": {
3905 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3906 },
3907 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 "log": {
3910 "op": Op.LOG,
3911 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003912 "build_fcn": (
3913 build_unary,
3914 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003915 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003916 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003918 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 "error_if_validators": (
3920 TosaErrorValidator.evWrongInputType,
3921 TosaErrorValidator.evWrongOutputType,
3922 TosaErrorValidator.evWrongInputList,
3923 TosaErrorValidator.evWrongOutputList,
3924 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003925 "data_gen": {
3926 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3927 },
3928 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003929 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "logical_not": {
3931 "op": Op.LOGICAL_NOT,
3932 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003933 "build_fcn": (
3934 build_unary,
3935 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003936 TosaTensorValuesGen.tvgLazyGenDefault,
3937 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003938 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003939 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 "error_if_validators": (
3941 TosaErrorValidator.evWrongInputType,
3942 TosaErrorValidator.evWrongOutputType,
3943 TosaErrorValidator.evWrongInputList,
3944 TosaErrorValidator.evWrongOutputList,
3945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "negate": {
3948 "op": Op.NEGATE,
3949 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 "build_fcn": (
3951 build_unary,
3952 TosaTensorGen.tgBasic,
3953 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003954 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 "qgen": TosaQuantGen.qgUnary,
3957 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003958 "error_if_validators": (
3959 TosaErrorValidator.evInputZeroPointNotZero,
3960 TosaErrorValidator.evOutputZeroPointNotZero,
3961 TosaErrorValidator.evWrongInputType,
3962 TosaErrorValidator.evWrongOutputType,
3963 TosaErrorValidator.evWrongInputList,
3964 TosaErrorValidator.evWrongOutputList,
3965 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003966 "data_gen": {
3967 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 "reciprocal": {
3971 "op": Op.RECIPROCAL,
3972 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_unary,
3975 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003976 TosaTensorValuesGen.tvgLazyGenDefault,
3977 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 "error_if_validators": (
3981 TosaErrorValidator.evWrongInputType,
3982 TosaErrorValidator.evWrongOutputType,
3983 TosaErrorValidator.evWrongInputList,
3984 TosaErrorValidator.evWrongOutputList,
3985 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003986 "data_gen": {
3987 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3988 },
3989 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "rsqrt": {
3992 "op": Op.RSQRT,
3993 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_unary,
3996 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003997 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003998 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evWrongInputType,
4003 TosaErrorValidator.evWrongOutputType,
4004 TosaErrorValidator.evWrongInputList,
4005 TosaErrorValidator.evWrongOutputList,
4006 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004007 "data_gen": {
4008 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4009 },
4010 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004012 # Elementwise Ternary operators
4013 "select": {
4014 "op": Op.SELECT,
4015 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 "build_fcn": (
4017 build_select,
4018 TosaTensorGen.tgBroadcastFuzz,
4019 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004020 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004023 "error_if_validators": (
4024 TosaErrorValidator.evRankMismatch,
4025 TosaErrorValidator.evWrongInputType,
4026 TosaErrorValidator.evWrongOutputType,
4027 TosaErrorValidator.evWrongInputList,
4028 TosaErrorValidator.evWrongOutputList,
4029 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004030 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004031 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004032 "data_gen": {
4033 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4034 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004036 # Comparison operators
4037 "equal": {
4038 "op": Op.EQUAL,
4039 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004040 "build_fcn": (
4041 build_comparison,
4042 TosaTensorGen.tgBroadcastFuzz,
4043 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004044 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004045 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004046 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004047 "error_if_validators": (
4048 TosaErrorValidator.evRankMismatch,
4049 TosaErrorValidator.evWrongInputType,
4050 TosaErrorValidator.evWrongOutputType,
4051 TosaErrorValidator.evWrongInputList,
4052 TosaErrorValidator.evWrongOutputList,
4053 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004054 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004055 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004056 "data_gen": {
4057 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 "greater_equal": {
4061 "op": Op.GREATER_EQUAL,
4062 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004063 "build_fcn": (
4064 build_comparison,
4065 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004066 TosaTensorValuesGen.tvgLazyGenDefault,
4067 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004070 "error_if_validators": (
4071 TosaErrorValidator.evRankMismatch,
4072 TosaErrorValidator.evWrongInputType,
4073 TosaErrorValidator.evWrongOutputType,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004077 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004078 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004079 "data_gen": {
4080 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 "greater": {
4084 "op": Op.GREATER,
4085 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_comparison,
4088 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004089 TosaTensorValuesGen.tvgLazyGenDefault,
4090 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evRankMismatch,
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongInputList,
4098 TosaErrorValidator.evWrongOutputList,
4099 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004100 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004101 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004102 "data_gen": {
4103 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004106 # Reduction operators
4107 "reduce_all": {
4108 "op": Op.REDUCE_ALL,
4109 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004110 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 "build_fcn": (
4112 build_reduce,
4113 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004114 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004115 TosaArgGen.agAxis,
4116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 "error_if_validators": (
4119 TosaErrorValidator.evAxisLargerRank,
4120 TosaErrorValidator.evAxisSmallerZero,
4121 TosaErrorValidator.evShapeOfAxisNotOne,
4122 TosaErrorValidator.evWrongInputType,
4123 TosaErrorValidator.evWrongOutputType,
4124 TosaErrorValidator.evWrongRank,
4125 TosaErrorValidator.evWrongInputList,
4126 TosaErrorValidator.evWrongOutputList,
4127 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004128 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004129 "reduce_any": {
4130 "op": Op.REDUCE_ANY,
4131 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004132 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004133 "build_fcn": (
4134 build_reduce,
4135 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004136 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004137 TosaArgGen.agAxis,
4138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004139 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004140 "error_if_validators": (
4141 TosaErrorValidator.evAxisLargerRank,
4142 TosaErrorValidator.evAxisSmallerZero,
4143 TosaErrorValidator.evShapeOfAxisNotOne,
4144 TosaErrorValidator.evWrongInputType,
4145 TosaErrorValidator.evWrongOutputType,
4146 TosaErrorValidator.evWrongRank,
4147 TosaErrorValidator.evWrongInputList,
4148 TosaErrorValidator.evWrongOutputList,
4149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 "reduce_max": {
4152 "op": Op.REDUCE_MAX,
4153 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004154 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004155 "build_fcn": (
4156 build_reduce,
4157 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004158 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004159 TosaArgGen.agAxis,
4160 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004161 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 "error_if_validators": (
4163 TosaErrorValidator.evAxisLargerRank,
4164 TosaErrorValidator.evAxisSmallerZero,
4165 TosaErrorValidator.evShapeOfAxisNotOne,
4166 TosaErrorValidator.evWrongInputType,
4167 TosaErrorValidator.evWrongOutputType,
4168 TosaErrorValidator.evWrongRank,
4169 TosaErrorValidator.evWrongInputList,
4170 TosaErrorValidator.evWrongOutputList,
4171 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004172 "data_gen": {
4173 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4174 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004176 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004177 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004179 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004180 "build_fcn": (
4181 build_reduce,
4182 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004183 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004184 TosaArgGen.agAxis,
4185 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004187 "error_if_validators": (
4188 TosaErrorValidator.evAxisLargerRank,
4189 TosaErrorValidator.evAxisSmallerZero,
4190 TosaErrorValidator.evShapeOfAxisNotOne,
4191 TosaErrorValidator.evWrongInputType,
4192 TosaErrorValidator.evWrongOutputType,
4193 TosaErrorValidator.evWrongRank,
4194 TosaErrorValidator.evWrongInputList,
4195 TosaErrorValidator.evWrongOutputList,
4196 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004197 "data_gen": {
4198 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 "reduce_product": {
4202 "op": Op.REDUCE_PRODUCT,
4203 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004204 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004205 "build_fcn": (
4206 build_reduce,
4207 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004208 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004209 TosaArgGen.agAxis,
4210 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004211 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004212 "error_if_validators": (
4213 TosaErrorValidator.evAxisLargerRank,
4214 TosaErrorValidator.evAxisSmallerZero,
4215 TosaErrorValidator.evShapeOfAxisNotOne,
4216 TosaErrorValidator.evWrongInputType,
4217 TosaErrorValidator.evWrongOutputType,
4218 TosaErrorValidator.evWrongRank,
4219 TosaErrorValidator.evWrongInputList,
4220 TosaErrorValidator.evWrongOutputList,
4221 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004222 "data_gen": {
4223 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4224 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004225 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 "reduce_sum": {
4227 "op": Op.REDUCE_SUM,
4228 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004229 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 "build_fcn": (
4231 build_reduce,
4232 TosaTensorGen.tgBasic,
4233 TosaTensorValuesGen.tvgReduceSum,
4234 TosaArgGen.agAxis,
4235 ),
James Ward24dbc422022-10-19 12:20:31 +01004236 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004237 "error_if_validators": (
4238 TosaErrorValidator.evAxisLargerRank,
4239 TosaErrorValidator.evAxisSmallerZero,
4240 TosaErrorValidator.evShapeOfAxisNotOne,
4241 TosaErrorValidator.evWrongInputType,
4242 TosaErrorValidator.evWrongOutputType,
4243 TosaErrorValidator.evWrongRank,
4244 TosaErrorValidator.evWrongInputList,
4245 TosaErrorValidator.evWrongOutputList,
4246 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004247 "data_gen": {
4248 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4249 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004250 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004251 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004252 "concat": {
4253 "op": Op.CONCAT,
4254 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004255 "build_fcn": (
4256 build_concat,
4257 TosaTensorGen.tgConcat,
4258 TosaTensorValuesGen.tvgConcat,
4259 TosaArgGen.agAxis,
4260 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004261 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004262 "error_if_validators": (
4263 TosaErrorValidator.evAxisLargerRank,
4264 TosaErrorValidator.evAxisSmallerZero,
4265 TosaErrorValidator.evConcatInputRankMismatch,
4266 TosaErrorValidator.evConcatShapeSumMismatch,
4267 TosaErrorValidator.evConcatInputDimMismatch,
4268 TosaErrorValidator.evWrongInputType,
4269 TosaErrorValidator.evWrongOutputType,
4270 TosaErrorValidator.evWrongOutputList,
4271 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004272 "data_gen": {
4273 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4274 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004275 },
4276 "pad": {
4277 "op": Op.PAD,
4278 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004279 "build_fcn": (
4280 build_pad,
4281 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004282 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004283 TosaArgGen.agPad,
4284 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004285 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004286 "error_if_validators": (
4287 TosaErrorValidator.evWrongInputType,
4288 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004289 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004290 TosaErrorValidator.evWrongOutputType,
4291 TosaErrorValidator.evWrongInputList,
4292 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004293 TosaErrorValidator.evRankMismatch,
4294 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004295 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004296 "data_gen": {
4297 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4298 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004299 },
Won Jeona21b2e82023-08-10 10:33:01 +00004300 "dim": {
4301 "op": Op.DIM,
4302 "operands": (1, 0),
4303 "build_fcn": (
4304 build_dim,
4305 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004306 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004307 TosaArgGen.agAxis,
4308 ),
4309 "types": TYPE_FIB,
4310 "error_if_validators": (
4311 TosaErrorValidator.evAxisLargerRank,
4312 TosaErrorValidator.evAxisSmallerZero,
4313 TosaErrorValidator.evWrongInputType,
4314 TosaErrorValidator.evWrongInputList,
4315 TosaErrorValidator.evWrongOutputList,
4316 TosaErrorValidator.evWrongRank,
4317 ),
4318 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004319 "reshape": {
4320 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004321 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004322 "build_fcn": (
4323 build_reshape,
4324 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004325 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004326 TosaArgGen.agReshape,
4327 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004328 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004329 "error_if_validators": (
4330 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4331 TosaErrorValidator.evWrongInputType,
4332 TosaErrorValidator.evWrongOutputType,
4333 TosaErrorValidator.evWrongInputList,
4334 TosaErrorValidator.evWrongOutputList,
4335 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004336 "data_gen": {
4337 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4338 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004339 },
4340 "reverse": {
4341 "op": Op.REVERSE,
4342 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004343 "build_fcn": (
4344 build_reverse,
4345 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004346 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 TosaArgGen.agAxis,
4348 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004349 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004350 "error_if_validators": (
4351 TosaErrorValidator.evAxisSmallerZero,
4352 TosaErrorValidator.evAxisLargerRank,
4353 TosaErrorValidator.evWrongInputType,
4354 TosaErrorValidator.evWrongOutputType,
4355 TosaErrorValidator.evWrongInputList,
4356 TosaErrorValidator.evWrongOutputList,
4357 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 },
4359 "slice": {
4360 "op": Op.SLICE,
4361 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004362 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004363 "build_fcn": (
4364 build_slice,
4365 TosaTensorGen.tgBasic,
evacha017f7d4252024-01-24 12:08:09 +00004366 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004367 TosaArgGen.agSlice,
4368 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004369 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004370 "error_if_validators": (
4371 TosaErrorValidator.evStartSmallerZero,
4372 TosaErrorValidator.evSizeSmallerEqualZero,
4373 TosaErrorValidator.evStartSizeOutsideBounds,
4374 TosaErrorValidator.evSizeOutputShapeMismatch,
4375 TosaErrorValidator.evInputSizeStartLengthMismatch,
4376 TosaErrorValidator.evWrongRank,
4377 TosaErrorValidator.evWrongInputType,
4378 TosaErrorValidator.evWrongOutputType,
4379 TosaErrorValidator.evWrongInputList,
4380 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004381 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004382 ),
evacha017f7d4252024-01-24 12:08:09 +00004383 "data_gen": {
4384 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4385 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004386 },
4387 "tile": {
4388 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004389 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004390 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004391 "build_fcn": (
4392 build_tile,
4393 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004394 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004395 TosaArgGen.agTile,
4396 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004397 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004398 "error_if_validators": (
4399 TosaErrorValidator.evWrongInputType,
4400 TosaErrorValidator.evWrongOutputType,
4401 TosaErrorValidator.evWrongInputList,
4402 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004403 TosaErrorValidator.evRankMismatch,
4404 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004405 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004406 "data_gen": {
4407 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4408 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004409 },
4410 "transpose": {
4411 "op": Op.TRANSPOSE,
4412 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004413 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004414 "build_fcn": (
4415 build_transpose,
4416 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004417 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004418 TosaArgGen.agTranspose,
4419 ),
4420 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004421 "error_if_validators": (
4422 TosaErrorValidator.evIndexOutsideBounds,
4423 TosaErrorValidator.evIndexUsedTwice,
4424 TosaErrorValidator.evWrongInputType,
4425 TosaErrorValidator.evWrongOutputType,
4426 TosaErrorValidator.evWrongInputList,
4427 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004428 TosaErrorValidator.evWrongRank,
4429 TosaErrorValidator.evRankMismatch,
4430 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004431 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 # Data nodes
4434 "const": {
4435 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004436 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 "build_fcn": (
4438 build_const,
4439 TosaTensorGen.tgBasic,
4440 TosaTensorValuesGen.tvgDefault,
4441 None,
4442 ),
Luke Hutton65872422023-02-20 10:33:04 +00004443 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004444 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004445 "identity": {
4446 "op": Op.IDENTITY,
4447 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004448 "build_fcn": (
4449 build_unary,
4450 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004451 TosaTensorValuesGen.tvgLazyGenDefault,
4452 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004453 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004454 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004455 "data_gen": {
4456 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4457 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004458 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004459 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 "gather": {
4461 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004462 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004463 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004464 "build_fcn": (
4465 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004466 TosaTensorGen.tgGather,
4467 TosaTensorValuesGen.tvgGather,
4468 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004469 ),
James Ward24dbc422022-10-19 12:20:31 +01004470 "types": (
4471 DType.INT8,
4472 DType.INT16,
4473 DType.INT32,
4474 DType.FP16,
4475 DType.BF16,
4476 DType.FP32,
4477 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004478 "error_if_validators": (
4479 TosaErrorValidator.evWrongInputType,
4480 TosaErrorValidator.evWrongOutputType,
4481 TosaErrorValidator.evWrongInputList,
4482 TosaErrorValidator.evWrongOutputList,
4483 TosaErrorValidator.evWrongRank,
4484 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004485 "data_gen": {
4486 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4487 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004488 },
4489 "scatter": {
4490 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004491 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004492 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004493 "build_fcn": (
4494 build_scatter,
4495 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004496 TosaTensorValuesGen.tvgScatter,
4497 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004498 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004499 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004500 "error_if_validators": (
4501 TosaErrorValidator.evWrongInputType,
4502 TosaErrorValidator.evWrongOutputType,
4503 TosaErrorValidator.evWrongInputList,
4504 TosaErrorValidator.evWrongOutputList,
4505 TosaErrorValidator.evWrongRank,
4506 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004507 "data_gen": {
4508 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4509 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004510 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004511 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004512 "resize": {
4513 "op": Op.RESIZE,
4514 "operands": (1, 0),
4515 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004516 "build_fcn": (
4517 build_resize,
4518 TosaTensorGen.tgNHWC,
4519 TosaTensorValuesGen.tvgDefault,
4520 TosaArgGen.agResize,
4521 ),
James Ward24dbc422022-10-19 12:20:31 +01004522 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004523 "invalid_test_validators": (
4524 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004525 ),
4526 "error_if_validators": (
4527 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004528 TosaErrorValidator.evScaleSmallerEqualZero,
4529 TosaErrorValidator.evScaleNLargerMax,
4530 TosaErrorValidator.evScaleDLargerMax,
4531 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004532 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004533 TosaErrorValidator.evBorderSmallerMin,
4534 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004535 TosaErrorValidator.evWrongInputType,
4536 TosaErrorValidator.evWrongOutputType,
4537 TosaErrorValidator.evWrongRank,
4538 TosaErrorValidator.evWrongInputList,
4539 TosaErrorValidator.evWrongOutputList,
4540 TosaErrorValidator.evBatchMismatch,
4541 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004542 TosaErrorValidator.evResizeOutputShapeMismatch,
4543 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004544 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004545 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004546 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004547 "cast": {
4548 "op": Op.CAST,
4549 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004550 "build_fcn": (
4551 build_cast,
4552 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004553 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004554 TosaArgGen.agCast,
4555 ),
James Ward8b390432022-08-12 20:48:56 +01004556 "types": (
4557 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004558 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004559 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004560 DType.INT8,
4561 DType.INT16,
4562 DType.INT32,
4563 DType.BOOL,
4564 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004565 "error_if_validators": (
4566 TosaErrorValidator.evWrongInputType,
4567 TosaErrorValidator.evWrongOutputType,
4568 TosaErrorValidator.evWrongInputList,
4569 TosaErrorValidator.evWrongOutputList,
4570 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004571 "data_gen": {
4572 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4573 },
4574 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004575 },
4576 "rescale": {
4577 "op": Op.RESCALE,
4578 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004579 "build_fcn": (
4580 build_rescale,
4581 TosaTensorGen.tgBasic,
4582 TosaTensorValuesGen.tvgDefault,
4583 TosaArgGen.agRescale,
4584 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004585 "types": [
4586 DType.UINT8,
4587 DType.INT8,
4588 DType.INT16,
4589 DType.INT32,
4590 DType.INT48,
4591 DType.UINT16,
4592 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004593 "error_if_validators": (
4594 TosaErrorValidator.evInputZeroPointNotZero,
4595 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004596 TosaErrorValidator.evU16InputZeroPointNotValid,
4597 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004598 TosaErrorValidator.evScaleTrue,
4599 TosaErrorValidator.evScaleNotTrue,
4600 TosaErrorValidator.evWrongInputType,
4601 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004602 TosaErrorValidator.evWrongInputList,
4603 TosaErrorValidator.evWrongOutputList,
4604 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004605 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004606 # Custom
4607 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004608 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004609 # Two variants of cond_if, one that generates one of two constant tensors (no
4610 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4611 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004612 "cond_if_const": {
4613 "op": Op.COND_IF,
4614 "operands": (0, 2),
4615 "build_fcn": (
4616 build_cond_if_const,
4617 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004618 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004619 TosaArgGen.agCondIf,
4620 ),
4621 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004622 "error_if_validators": (
4623 TosaErrorValidator.evOutputListThenGraphMismatch,
4624 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004625 TosaErrorValidator.evCondIfCondNotMatchingBool,
4626 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004627 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004628 },
4629 "cond_if_binary": {
4630 "op": Op.COND_IF,
4631 "operands": (2, 0),
4632 "build_fcn": (
4633 build_cond_if_binary,
4634 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004635 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004636 TosaArgGen.agCondIf,
4637 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004638 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004639 "error_if_validators": (
4640 TosaErrorValidator.evInputListThenGraphMismatch,
4641 TosaErrorValidator.evInputListElseGraphMismatch,
4642 TosaErrorValidator.evOutputListThenGraphMismatch,
4643 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004644 TosaErrorValidator.evCondIfCondNotMatchingBool,
4645 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004647 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004648 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004649 "while_loop": {
4650 "op": Op.WHILE_LOOP,
4651 "operands": (0, 1),
4652 "build_fcn": (
4653 build_while_loop,
4654 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004655 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004656 TosaArgGen.agWhileLoop,
4657 ),
4658 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004659 "error_if_validators": (
4660 TosaErrorValidator.evInputListOutputListMismatch,
4661 TosaErrorValidator.evInputListCondGraphMismatch,
4662 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4663 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4664 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004665 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004667 },
Luke Hutton57287132023-02-06 14:54:18 +00004668 "fft2d": {
4669 "op": Op.FFT2D,
4670 "operands": (2, 0),
4671 "rank": (3, 3),
4672 "build_fcn": (
4673 build_fft2d,
4674 TosaTensorGen.tgFFT2d,
4675 TosaTensorValuesGen.tvgDefault,
4676 TosaArgGen.agFFT2d,
4677 ),
4678 "types": [DType.FP32],
4679 "error_if_validators": (
4680 TosaErrorValidator.evWrongInputType,
4681 TosaErrorValidator.evWrongOutputType,
4682 TosaErrorValidator.evWrongInputList,
4683 TosaErrorValidator.evWrongOutputList,
4684 TosaErrorValidator.evWrongRank,
4685 TosaErrorValidator.evBatchMismatch,
4686 TosaErrorValidator.evKernelNotPowerOfTwo,
4687 TosaErrorValidator.evFFTInputShapeMismatch,
4688 TosaErrorValidator.evFFTOutputShapeMismatch,
4689 ),
4690 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004691 "rfft2d": {
4692 "op": Op.RFFT2D,
4693 "operands": (1, 0),
4694 "rank": (3, 3),
4695 "build_fcn": (
4696 build_rfft2d,
4697 TosaTensorGen.tgRFFT2d,
4698 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004699 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004700 ),
4701 "types": [DType.FP32],
4702 "error_if_validators": (
4703 TosaErrorValidator.evWrongInputType,
4704 TosaErrorValidator.evWrongOutputType,
4705 TosaErrorValidator.evWrongInputList,
4706 TosaErrorValidator.evWrongOutputList,
4707 TosaErrorValidator.evWrongRank,
4708 TosaErrorValidator.evBatchMismatch,
4709 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004710 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004711 ),
4712 },
Won Jeon74342e52024-01-09 00:34:40 +00004713 # Shape
4714 "add_shape": {
4715 "op": Op.ADD_SHAPE,
4716 "operands": (2, 0),
4717 "build_fcn": (
4718 build_shape_op,
4719 TosaTensorGen.tgShape,
4720 TosaTensorValuesGen.tvgAddSub,
4721 TosaArgGen.agNone,
4722 ),
4723 "types": [DType.SHAPE],
4724 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4725 },
4726 "sub_shape": {
4727 "op": Op.SUB_SHAPE,
4728 "operands": (2, 0),
4729 "build_fcn": (
4730 build_shape_op,
4731 TosaTensorGen.tgShape,
4732 TosaTensorValuesGen.tvgAddSub,
4733 TosaArgGen.agNone,
4734 ),
4735 "types": [DType.SHAPE],
4736 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4737 },
4738 "mul_shape": {
4739 "op": Op.MUL_SHAPE,
4740 "operands": (2, 0),
4741 "build_fcn": (
4742 build_shape_op,
4743 TosaTensorGen.tgShape,
4744 TosaTensorValuesGen.tvgMul,
4745 TosaArgGen.agNone,
4746 ),
4747 "types": [DType.SHAPE],
4748 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4749 },
4750 "div_shape": {
4751 "op": Op.DIV_SHAPE,
4752 "operands": (2, 0),
4753 "build_fcn": (
4754 build_shape_op,
4755 TosaTensorGen.tgShape,
4756 TosaTensorValuesGen.tvgIntDiv,
4757 TosaArgGen.agNone,
4758 ),
4759 "types": [DType.SHAPE],
4760 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4761 },
4762 "concat_shape": {
4763 "op": Op.CONCAT_SHAPE,
4764 "operands": (2, 0),
4765 "build_fcn": (
4766 build_concat,
4767 TosaTensorGen.tgConcat,
4768 TosaTensorValuesGen.tvgConcat,
4769 TosaArgGen.agNone,
4770 ),
4771 "types": [DType.SHAPE],
4772 "error_if_validators": (),
4773 },
4774 "const_shape": {
4775 "op": Op.CONST_SHAPE,
4776 "operands": (0, 1),
4777 "build_fcn": (
4778 build_const,
4779 TosaTensorGen.tgBasic,
4780 TosaTensorValuesGen.tvgDefault,
4781 None,
4782 ),
4783 "types": [DType.SHAPE],
4784 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004785 }
4786
Kevin Cheng550ccc52021-03-03 11:21:43 -08004787
Eric Kunzee5e26762020-10-13 16:11:07 -07004788class OutputShaper:
4789 # Methods in this class compute the expected output shape and datatype
4790 # for common classes of operations
def __init__(self):
    """OutputShaper keeps no per-instance state; its helpers are static methods."""
4793
# Each method below computes an operator's output shape and dtype, creates
# the output tensor via ser.addOutput(), and returns addOutput's result.
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting element-wise binary op.

    Each output dimension takes ``b``'s size where ``a``'s size is 1 (only
    when no ERROR_IF case is being generated); otherwise it takes ``a``'s size.

    ERROR_IF handling:

    * ``ErrorIf.DimensionMismatch`` - one random dimension is increased by 1.
    * ``ErrorIf.WrongOutputType`` - a dtype other than ``a.dtype`` is drawn.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (``integers`` / ``choice`` are called on it).
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, a_dim in enumerate(a.shape):
        if a_dim == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a_dim)

    # rng.integers(0, 0) raises ValueError, so a rank-0 input (no dimension
    # to fuzz) must skip the draw. For rank >= 1 the index is drawn on every
    # call, whatever error_name is.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1
    else:
        assert (
            error_name != ErrorIf.DimensionMismatch
        ), "DimensionMismatch needs at least one dimension to fuzz"

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004829
4830 @staticmethod
4831 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004832 assert len(a.shape) == len(b.shape)
4833 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004834
4835 shape = []
4836 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004837 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004838 shape.append(a.shape[i])
4839
Kevin Cheng550ccc52021-03-03 11:21:43 -08004840 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004841
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output shape always equals ``a.shape``. For
    ``ErrorIf.WrongOutputType`` a dtype other than ``a.dtype`` is drawn from
    ``rng``; in every other case the output keeps ``a.dtype``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate other than the input dtype is invalid for this test.
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004860
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for SELECT.

    Output dimensions follow ``cond``. Where ``cond`` has size 1 (and no
    ERROR_IF case is being generated), the largest of the three input sizes
    is used instead.

    ERROR_IF handling:

    * ``ErrorIf.DimensionMismatch`` - one random dimension is increased by 1.
    * ``ErrorIf.WrongOutputType`` - a dtype other than ``a.dtype`` is drawn.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            shape.append(max(cond_dim, a.shape[i], b.shape[i]))
        else:
            shape.append(cond_dim)

    # rng.integers(0, 0) raises ValueError, so a rank-0 input (no dimension
    # to fuzz) must skip the draw. For rank >= 1 the index is drawn on every
    # call, whatever error_name is.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1
    else:
        assert (
            error_name != ErrorIf.DimensionMismatch
        ), "DimensionMismatch needs at least one dimension to fuzz"

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004894
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor for a broadcasting comparison op.

    Each output dimension takes ``b``'s size where ``a``'s size is 1 and ``b``
    has that dimension; otherwise it takes ``a``'s size.

    ERROR_IF handling:

    * ``ErrorIf.DimensionMismatch`` - one random dimension is increased by 1.
    * ``ErrorIf.WrongOutputType`` - a non-BOOL dtype is drawn.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i, a_dim in enumerate(a.shape):
        if a_dim == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a_dim)

    # rng.integers(0, 0) raises ValueError, so a rank-0 input (no dimension
    # to fuzz) must skip the draw. For rank >= 1 the index is drawn on every
    # call, whatever error_name is.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1
    else:
        assert (
            error_name != ErrorIf.DimensionMismatch
        ), "DimensionMismatch needs at least one dimension to fuzz"

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so any choice is a wrong output type.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004928
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a REDUCE_* op.

    Normally the reduced ``axis`` becomes size 1. For the axis-related
    ERROR_IF cases (AxisSmallerZero, AxisLargerRank, ShapeOfAxisNotOne) the
    input shape is kept. For ShapeOfAxisNotOne, a size-1 axis is also replaced
    by a random size in 2..9 so that it is guaranteed not to be 1.
    ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name in axis_errors:
        if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    else:
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004957
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the INT32 output tensor for ARGMAX.

    The output is ``a``'s shape with ``axis`` removed. The axis is kept when
    it is deliberately invalid (AxisSmallerZero / AxisLargerRank).

    ERROR_IF handling:

    * ``ArgmaxOutputRankMismatch`` - the rank is changed by one, either by
      dropping the leading dim (if more than one remains) or by appending a 1.
    * ``ArgmaxOutputShapeMismatch`` - every dim grows by a random 1..9.
    * ``WrongOutputType`` - a dtype other than INT32 is drawn.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004991
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: ``ifm`` is NHWC, ``filter`` is OHWI and the output is NHWC. Each
    spatial size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ERROR_IF handling:

    * ``ConvOutputShapeMismatch`` - H and/or W grow by a random multiple of
      the matching stride.
    * ``WrongInputType`` - INT32 is used as a potentially valid output type.
    * ``WrongOutputType`` - a usable dtype outside the valid result set is
      drawn (FP16 inputs exclude both FP16 and FP32).
    """

    def spatial(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Output extent along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output check is not hit.
        if which in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput([ifm.shape[0], out_h, out_w, filter.shape[0]], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005043
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: ``ifm`` is NDHWC, ``filter`` is ODHWI and the output is NDHWC.
    Each spatial size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ERROR_IF handling:

    * ``ConvOutputShapeMismatch`` - D, H and/or W grow by a random multiple of
      the matching stride.
    * ``WrongInputType`` - INT32 is used as a potentially valid output type.
    * ``WrongOutputType`` - a usable dtype outside the valid result set is
      drawn (FP16 inputs exclude both FP16 and FP32).
    """

    def spatial(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Output extent along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_d = spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = spatial(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output check is not hit.
        if which in [1, 4]:
            out_d += rng.choice(steps) * strides[0]
        if which in [2, 4]:
            out_h += rng.choice(steps) * strides[1]
        if which in [3, 4]:
            out_w += rng.choice(steps) * strides[2]

    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(
        [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]], out_dtype
    )
5105
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: ``ifm`` is NHWC, ``filter`` is HWCM and the output is
    NHW(C*M). Each spatial size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ERROR_IF handling:

    * ``ConvOutputShapeMismatch`` - H and/or W grow by a random multiple of
      the matching stride.
    * ``WrongInputType`` - INT32 is used as a potentially valid output type.
    * ``WrongOutputType`` - a usable dtype outside the valid result set is
      drawn (FP16 inputs exclude both FP16 and FP32).
    """

    def spatial(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Output extent along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output check is not hit.
        if which in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]

    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput([ifm.shape[0], out_h, out_w, out_channels], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005156
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op (NHWC in and out).

    With a non-positive stride or any negative pad the spatial size is set to
    1, because such a test is invalid anyway. Otherwise each spatial size is
    ``(in + pad_before + pad_after - kernel) // stride + 1``.
    ``PoolingOutputShapeMismatch`` grows H and/or W by a random multiple of
    the stride. ``WrongOutputType`` draws a dtype other than ``ifm.dtype``.
    """
    invalid_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if invalid_params:
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output check is not hit.
        if which in [1, 3]:
            out_h += rng.choice(steps) * stride[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * stride[1]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput([ifm.shape[0], out_h, out_w, ifm.shape[3]], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005194
5195 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005196 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005197 # input: N, IC
5198 # filter: OC, IC
5199 # output: N, OC
5200
5201 output_shape = [input.shape[0], filter.shape[0]]
5202
James Ward8b390432022-08-12 20:48:56 +01005203 # Validated in arg_gen (also invalidated for ErrorIf)
5204 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005205
Kevin Cheng550ccc52021-03-03 11:21:43 -08005206 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005207
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W] and the output is
    [N, H, W].

    Output dtype:

    * ``ErrorIf.WrongOutputType`` - a dtype drawn from ``rng`` that is not a
      valid result type for ``a.dtype``. Any input dtype without a
      hand-picked list falls back to every usable dtype except
      ``accum_dtype``.
    * ``ErrorIf.WrongInputType`` - INT32, a potentially valid output type.
    * otherwise - ``accum_dtype`` (validated by the argument generator).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is the only type omitted for INT8 inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # INT48 is the only type omitted for INT16 inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: every integer type is wrong.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # No hand-picked list for this input dtype; without this branch
            # incorrect_types would be unbound (UnboundLocalError). Use every
            # usable dtype except the expected accumulator type, as the conv
            # output shapers do.
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005255
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT.

    The output is the first input's shape with ``axis`` set to the sum of
    every input's size along ``axis``. The sum is skipped, and the first
    input's shape is used as-is, for ConcatInputRankMismatch, AxisLargerRank
    and AxisSmallerZero, where it cannot be computed.
    ``ConcatShapeSumMismatch`` adds a random 5..9 to the concatenated axis.
    ``WrongOutputType`` draws a dtype other than the first input's.
    """
    head = inputs[0]
    out_shape = head.shape.copy()

    summable = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if summable:
        for tensor in inputs[1:]:
            out_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {head.dtype}))
    else:
        out_dtype = head.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005291
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ``PadOutputShapeMismatch`` shrinks one random dim by 1 or 2.
    ``RankMismatch`` replaces the shape via ``gtu.get_rank_mismatch_shape``.
    ``WrongOutputType`` draws a dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    for dim_idx, dim in enumerate(out_shape):
        pad_before, pad_after = padding[dim_idx][0], padding[dim_idx][1]
        out_shape[dim_idx] = pad_before + pad_after + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005322
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM.

    The output always has shape ``[1]``. Its dtype is ``DType.SHAPE`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``; in that case a non-SHAPE
    dtype is drawn from ``rng``. ``a`` and ``axis`` are not used here.
    """
    if error_name == ErrorIf.WrongOutputType:
        non_shape_dtypes = list(
            set(
                [
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                ]
            )
        )
        out_dtype = rng.choice(non_shape_dtypes)
    else:
        out_dtype = DType.SHAPE

    return ser.addOutput([1], out_dtype)
5343
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE.

    The output takes the requested ``shape``. For
    ``ErrorIf.TensorSizeInputOutputMismatch`` every dim grows by a random 1..9,
    so the element count differs from ``shape``'s. For
    ``ErrorIf.WrongOutputType`` a dtype other than ``a.dtype`` is drawn.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005368
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE.

    The output shape is ``size``.

    ERROR_IF handling:

    * ``SizeOutputShapeMismatch`` - each dim is perturbed by a random amount:
      +1/+2 for dims <= 2, and one of -2, -1, +1, +2 otherwise.
    * ``InputSizeStartLengthMismatch`` - the input's shape is used instead.
    * ``RankMismatch`` - the shape comes from ``gtu.get_rank_mismatch_shape``.
    * ``WrongOutputType`` - a dtype other than ``input.dtype`` is drawn.

    ``start`` is not used here.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dims only grow, so they stay positive.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005402
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Each output dim is ``a.shape[i] * multiples[i]``; ``multiples`` must have
    one entry per input dim. ``RankMismatch`` replaces the shape via
    ``gtu.get_rank_mismatch_shape``. ``WrongOutputType`` draws a dtype other
    than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, (dim, mult) in enumerate(zip(a.shape, multiples)):
        out_shape[idx] = dim * mult

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005431
5432 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE operator.

    Normally output dimension ``i`` is ``a.shape[perms[i]]``. For the
    ``IndexOutsideBounds`` / ``IndexUsedTwice`` cases the permutation is not
    applied and the input shape is used as-is. Presumably this is because
    ``perms`` is invalid there; confirm against the arg generator.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used when corrupting the shape or dtype.
        a: input tensor; its ``shape`` and ``dtype`` are read.
        perms: permutation indices, one per input dimension.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(perms) == len(a.shape)

    if error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        permuted = a.shape.copy()
    else:
        permuted = [a.shape[axis] for axis in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension so the element count no longer matches.
        permuted = [dim + rng.integers(1, 10) for dim in permuted]
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    if error_name != ErrorIf.WrongOutputType:
        dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(permuted, dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005464
5465 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER operator.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    Unless the ``WrongRank`` case is being generated, ``values`` must be
    rank 3, ``indices`` rank 2, and their leading dimensions must agree.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used to pick a wrong dtype.
        values: values tensor; its ``shape`` and ``dtype`` are read.
        indices: indices tensor; its ``shape`` is read.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    gathered = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        dtype = values.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        dtype = rng.choice(list(set(candidates) - {values.dtype}))

    return ser.addOutput(gathered, dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005490
5491 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER operator.

    The output has the same shape as ``values_in``. Unless the ``WrongRank``
    case is being generated, ``values_in`` and ``input`` must be rank 3 and
    ``indices`` rank 2. Their N / W / C dimensions must also agree, as
    checked below.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used to pick a wrong dtype.
        values_in: tensor whose ``shape`` and ``dtype`` define the output.
        indices: indices tensor; its ``shape`` is read.
        input: tensor of values to scatter; its ``shape`` is read.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy the shape so the output tensor does not share (alias) the input
    # tensor's shape list, matching the other shapers which all copy.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005519
5520 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE operator.

    The output keeps the input shape. Its dtype depends on the input dtype:
    INT16 input gives INT32 output, anything else gives INT8. For
    ``WrongOutputType`` a different dtype is chosen at random.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used to pick a wrong dtype.
        input: input tensor; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        alternatives = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != result_dtype
        ]
        result_dtype = rng.choice(alternatives)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005539
5540 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operator.

    The output height and width come from the input's ``shape[1]`` and
    ``shape[2]`` and the scale/offset/border parameters:
    ``((in - 1) * numerator - offset + border) // denominator + 1``.
    For ERROR_IF tests the sizes are clamped to stay usable, and the shape
    may then be perturbed according to ``error_name``.

    Args:
        serializer: object whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used for the shape perturbations.
        input: input tensor; ``shape[0]`` to ``shape[3]`` are read.
        mode: not used by this function.
        scale: ``[y_numerator, y_denominator, x_numerator, x_denominator]``.
        offset: ``[y_offset, x_offset]``.
        border: ``[y_border, x_border]``.
        input_dtype: not used by this function.
        output_dtype: dtype given to the created tensor.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    y_num = scale[0]
    y_den = scale[1]
    x_num = scale[2]
    x_den = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # The scale is deliberately invalid in this case; clamp it to >= 1
        # so the size arithmetic below never divides by zero.
        y_num, y_den, x_num, x_den = (
            max(value, 1) for value in (y_num, y_den, x_num, x_den)
        )

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Changed scale/offset/border values can yield a degenerate size:
        # keep at least 1 and, except when testing MaxDimExceeded, stay
        # below the resize dimension limit.
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            ceiling = gtu.MAX_RESIZE_DIMENSION - 1
            out_h = min(out_h, ceiling)
            out_w = min(out_w, ceiling)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def nudge(size, step):
            # Step by a whole denominator so we don't hit the non-integer
            # error case; step down instead if stepping up reaches the limit.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        # 1 -> height only, 2 -> width only, 3 -> both
        change = rng.choice([1, 2, 3])
        if change in [1, 3]:
            out_h = nudge(out_h, y_den)
        if change in [2, 3]:
            out_w = nudge(out_w, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim is input.shape[0], not input.shape[3],
        # presumably because the input may lack a 4th dim here; confirm.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005618
5619 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type-conversion operator.

    The output keeps the shape of ``val`` and takes ``out_dtype``.
    ``rng`` and ``error_name`` are accepted for interface uniformity but
    are not used.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005622
5623 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operator.

    The output dtype is normally ``accum_dtype``. For ``WrongInputType`` it
    is INT32 instead. For ``WrongOutputType`` a random usable dtype is
    chosen, excluding FP16/FP32 for FP16 inputs and otherwise excluding
    the normal output dtype.

    Note: for ``ConvOutputShapeMismatch`` the caller's ``output_shape`` list
    is modified in place (index 1 and/or 2 are increased) before use.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used for the shape/dtype perturbations.
        ifm: input tensor; its ``dtype`` is read.
        output_shape: output shape list, used directly for the tensor.
        accum_dtype: accumulator dtype, used as the normal output dtype.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        increments = [1, 2, 3]
        # 1 -> height only, 2 -> width only, 3 -> both
        selector = rng.choice(increments)
        if selector in [1, 3]:
            output_shape[1] += rng.choice(increments)
        if selector in [2, 3]:
            output_shape[2] += rng.choice(increments)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            banned = [DType.FP16, DType.FP32]
        else:
            banned = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=banned)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005648
5649 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for an FFT2D operator.

    Both outputs copy ``ifm1``'s shape and dtype. For ERROR_IF tests the
    dtype may be replaced (``WrongOutputType``), the batch dim grown
    (``BatchMismatch``), or dim 1 or 2 grown (``FFTOutputShapeMismatch``).
    The two inputs must share a dtype and, unless ``FFTInputShapeMismatch``
    is being generated, a shape.

    Args:
        serializer: object whose ``addOutput(shape, dtype)`` creates tensors.
        rng: random generator, used for the perturbations.
        ifm1: first input tensor; its ``shape`` and ``dtype`` are read.
        ifm2: second input tensor; compared against ``ifm1``.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list of the two ``serializer.addOutput`` results.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Two outputs sharing the same shape object and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5679
5680 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for an RFFT2D operator.

    Each output keeps every input dim except the last, which becomes
    ``last // 2 + 1``. The dtype follows the input. For ERROR_IF tests the
    dtype may be replaced (``WrongOutputType``), the batch dim grown
    (``BatchMismatch``), or dim 1 or 2 grown (``FFTOutputShapeMismatch``).

    Args:
        serializer: object whose ``addOutput(shape, dtype)`` creates tensors.
        rng: random generator, used for the perturbations.
        value: input tensor; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list of the two ``serializer.addOutput`` results.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Two outputs sharing the same shape object and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005704
5705 @staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an ADD_SHAPE operator.

    The output copies ``a``'s shape and has dtype ``DType.SHAPE``. For
    ``DimensionMismatch`` one randomly chosen dim is increased by 1. For
    ``WrongOutputType`` the dtype is drawn from
    ``gtu.get_wrong_output_type(a.dtype)``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used to pick the dim and the wrong dtype.
        a: first operand; its ``shape`` and ``dtype`` are read.
        b: second operand; checked against ``a``.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    result_shape = list(a.shape)

    # The index is drawn on every call, so RNG consumption does not depend
    # on error_name.
    bump_at = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        result_shape[bump_at] += 1

    if error_name == ErrorIf.WrongOutputType:
        dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        dtype = DType.SHAPE
    return ser.addOutput(result_shape, dtype)