blob: d82f919c7d0a294f0d37fbe52f0128c1a9210d94 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of generated C++ meta-data files. The copyright
# year is taken from the local date when this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise generator state from the parsed command-line ``args``.

    Seeds the random generator from ``args.random_seed``, calls
    createDynamicOpLists and initOpListDefaults, creates the quantization
    helper, the desc.json schema validator and the data generator library,
    and resolves ``args.tensor_fp_value_range`` into a per-dtype range for
    FP32, FP16 and BF16 (stored in ``self.random_float_range``).
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # JSON schema validation of desc.json
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library is sometimes needed for compliance set up
    # even if the data itself is generated later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clampToType(limits, typeMax):
        # "max"/"-max" pick the dtype limit; any other value is trimmed to
        # [-typeMax, typeMax]. The result is sorted ascending.
        clamped = []
        for value in limits:
            if value == "max":
                clamped.append(typeMax)
            elif value == "-max":
                clamped.append(-typeMax)
            else:
                clamped.append(min(max(value, -typeMax), typeMax))
        return tuple(sorted(clamped))

    # Work out the floating point range for each FP dtype
    self.random_float_range = {
        dtype: clampToType(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
        )
        for dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create the output directory for a test and a fresh serializer for it.

    The directory is ``<basePath>/<opName>/<testPath>``; ``self.testPath``
    keeps the part relative to ``basePath``. Constants are embedded in the
    flatbuffer unless lazy data generation is on (constants are written as
    files) or ``args.dump_consts`` asks for them to be embedded and dumped.
    """
    self.testPath = os.path.join(opName, testPath)
    outDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(outDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the serializer from the latest createSerializer call (None before any)."""
    return self.ser
101
def serialize(self, testName, metaData=None):
    """Write the test's flatbuffer, optional C++ meta-data files and desc.json.

    Files are written to ``<basePath>/<testPath>`` (testPath is set by
    createSerializer):
      * ``<testName>.tosa``: the serialized TOSA flatbuffer.
      * ``<testName>_meta_data_gen.cpp``: only when ``metaData`` has a
        ``"data_gen"`` entry and lazy data generation is enabled.
      * ``<testName>_meta_compliance.cpp``: only when ``metaData`` has a
        ``"compliance"`` entry.
      * ``desc.json``: the JSON test descriptor, with ``metaData`` stored
        under ``"meta"`` when given.

    The descriptor is schema-validated before any meta-data file or
    desc.json is written.
    """
    path = Path(self.basePath) / self.testPath

    def writeCppMeta(suffix, description, varName, data):
        # Embed ``data`` as JSON inside a C++ raw string literal
        with (path / f"{testName}_{suffix}.cpp").open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// Test meta data for {description}\n\n")
            fd.write(f'const char* {varName}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    with (path / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output data generation meta data as CPP data
            writeCppMeta(
                "meta_data_gen",
                "data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppMeta(
                "meta_compliance",
                "compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    with (path / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded generator.

    Without an explicit ``seed`` the generator is seeded with
    ``self.random_seed + 1``.
    """
    self.rng = np.random.default_rng(
        self.random_seed + 1 if seed is None else seed
    )
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) random value range for ``dtype``.

    For non-float dtypes ``high`` is exclusive unless ``high_inclusive`` is
    True, in which case ``(low, high - 1)`` is returned. FP32/FP16/BF16
    return the configured ``self.random_float_range`` entry unchanged,
    whatever ``high_inclusive`` is. SHAPE uses the first two entries of
    ``args.tensor_shape_range``. Any other unknown dtype raises Exception.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        knownRanges = (
            (DType.BOOL, 0, 2),
            (DType.UINT8, 0, 256),
            (DType.UINT16, 0, 65536),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, -7, 8),
            (DType.INT8, -128, 128),
            (DType.INT16, -32768, 32768),
            (DType.INT32, -(1 << 31), (1 << 31)),
            (DType.INT48, -(1 << 47), (1 << 47)),
        )
        for candidate, low, high in knownRanges:
            if dtype == candidate:
                bounds = (low, high)
                break
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= range < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding ``dtype`` values.

    Values come from ``data_range`` (low, high) when given, otherwise from
    getDTypeRange(dtype). BOOL ignores the range and picks False/True.
    FP16 is returned as float16; FP32 and BF16 as float32, with BF16 passed
    through gtu.vect_f32_to_bf16. INT8/UINT8 use int8/uint8, INT48/SHAPE
    use int64 and every other dtype int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        samples32 = np.float32(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(samples32))
        return samples32

    if dtype == DType.INT8:
        cast = np.int8
    elif dtype == DType.UINT8:
        cast = np.uint8
    elif dtype in (DType.INT48, DType.SHAPE):
        cast = np.int64
    else:
        # All other integer types
        cast = np.int32
    return cast(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless lazy data generation
    is enabled, in which case None is passed as the data. Returns the list
    of values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
228
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant tensor per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless lazy data generation
    is enabled, in which case None is passed as the data. Returns the list
    of values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    def addOne(shape, dtype):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        return self.ser.addConst(shape, dtype, data)

    return [addOne(shape, dtype) for shape, dtype in zip(shape_list, dtype_list)]
241
def makeShape(self, rank):
    """Return an int32 shape array for a tensor of the given ``rank``.

    If a (truthy) target shape has been set via setTargetShape it is
    returned as-is, whatever ``rank`` is. Otherwise each of the ``rank``
    dimensions is drawn from [tensor_shape_range[0], tensor_shape_range[1]).
    """
    if not self.targetted_shape:
        dimLow = self.args.tensor_shape_range[0]
        dimHigh = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=dimLow, high=dimHigh, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
def setTargetShape(self, shape):
    """Make makeShape return ``shape`` (while it is truthy) instead of a random shape."""
    self.targetted_shape = shape
255
def randInt(self, low=0, high=256):
    """Return one random ``np.int32`` from the half-open range [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    first, = np.int32(draws)
    return first
258
def getRandNumberDType(self, dtype):
    """Return one random scalar value for ``dtype`` within getDTypeRange(dtype).

    FP32 gives np.float32 and FP16 np.float16. BF16 gives the result of
    gtu.vect_f32_to_bf16 applied to a float32 draw. BOOL is a random choice
    of False/True. INT48 and SHAPE give np.int64; every other dtype np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP32:
        return np.float32(self.rng.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype == DType.BF16:
        asF32 = np.float32(self.rng.uniform(low=low, high=high))
        return gtu.vect_f32_to_bf16(asF32)
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # Special size: INT48 and SHAPE values are returned as int64
    cast = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
    return cast(self.rng.integers(low, high, size=1))[0]
276
def shapeStr(self, shape):
    """Return ``shape`` as its dimensions joined by "x", e.g. ``"1x2x3"``."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeStr(self, dtype):
    """Return the short string name of ``dtype`` used in test names.

    A list/tuple of dtypes (at least two entries) is rendered as the names
    of its first two entries joined with "x"; every entry is still looked
    up, but the third (the accumulator type) is left out of the result.
    A single dtype missing from gtu.DTYPE_ATTRIBUTES raises Exception.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(entry) for entry in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
def typeWidth(self, dtype):
    """Return the "width" entry recorded for ``dtype`` in gtu.DTYPE_ATTRIBUTES.

    Raises Exception for dtypes that have no entry.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
def constrictBatchSize(self, shape):
    """Clamp the batch size ``shape[0]`` to ``args.max_batch_size``.

    The clamp is skipped when max_batch_size is falsy or explicit target
    shapes were requested. ``shape`` is modified in place and returned.
    """
    limit = self.args.max_batch_size
    if limit and not self.args.target_shapes:
        shape[0] = min(shape[0], limit)
    return shape
312
def makeDimension(self):
    """Return one random dimension size drawn via randInt from args.tensor_shape_range."""
    shapeRange = self.args.tensor_shape_range
    lowest, highest = shapeRange[0], shapeRange[1]
    return self.randInt(low=lowest, high=highest)
317
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for an operator's output tensor.

    Returns None for error tests (``errorName`` set), for output dtypes not
    supported by compliance, and for the dot-product style ops listed below
    when their input dtype is not supported. Otherwise returns a dict with
    ``mode`` (a gtu.ComplianceMode name), ``data_type`` and, depending on
    the mode, one ``*_info`` entry.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    unsupportedInputOps = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in unsupportedInputOps
    ):
        return None

    result = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dgType = argsDict["dg_type"]
    if dgType == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        ksKey = "ksb" if "ksb" in argsDict else "ks"
        result["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ksKey]),
        }
    elif dgType == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        result["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        result["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            result["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    result["mode"] = gtu.ComplianceMode(mode).name
    return result
374
375 # Build Op functions
376 # Create the output tensor (calling OutputShaper as needed)
377 # Do final tweaks to attributes (if necessary for errorIf)
378 # Add Op into graph
379 # Return resulting tensor information or BuildInfo
380
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        # A single tensor (and dict) is wrapped into one-element lists
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return ``{"version": "0.1", "tensors": {name: dict}}`` or None.

        Entries whose compliance dict is None are skipped; None is returned
        when no compliance list was given or every entry is None.
        """
        if self.complianceDictList is None:
            return None
        tensors = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not tensors:
            return None
        # Have some compliance data, so return the info
        return {"version": "0.1", "tensors": tensors}
Eric Kunzee5e26762020-10-13 16:11:07 -0700414
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    Args:
        op: operator definition mapping; "op" and "operands" are read here
            and the whole mapping is passed to the validators/compliance.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: two zero points indexed [0]/[1]; presumably (input, output),
            as built in the WrongOutputType branch. Only NEGATE uses it here.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # op must be the operator definition (it is indexed below), not a bare value
    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        # Validators returned a falsy value: do not build this test
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # NOTE(review): assumes qinfo is always supplied for NEGATE; a None
        # qinfo would raise TypeError here
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700467
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input operator whose inputs may be broadcast together.

    Args:
        op: operator definition mapping; "op" and "operands" are read here.
        inputs: the two input tensors ``[a, b]``.
        args_dict: test arguments, passed to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions. Defaults to None, like
            the other build_* methods that take ``inputs``/``args_dict``.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused; accepted so all build_* methods share a signature.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700509
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator with non-broadcast inputs and return its result tensor.

    NOTE(review): ``validator_fcns`` and ``error_name`` are accepted but not
    used, so no ERROR_IF validation or input/output invalidation happens.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operandNames = [a.name, b.name]
    self.ser.addOperator(op["op"], operandNames, [output.name])
    return output
514
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic-right-shift style operator of ``a`` by ``b``.

    ``round`` is forwarded to ArithmeticRightShiftAttribute. The parameter
    name shadows the builtin but is kept for caller compatibility. Returns
    the result tensor, or None when evValidateErrorIfs returns a falsy value.
    """
    resultTensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [resultTensor.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=resultTensor.dtype,
        result_tensors=[resultTensor],
        input_list=inputNames,
        output_list=outputNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    shiftAttr = ts.TosaSerializerAttribute()
    shiftAttr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], inputNames, outputNames, shiftAttr)
    return resultTensor
552
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator with a ``shift`` attribute to the graph.

    Args:
        op: operator definition mapping; "op" and "operands" are read here.
        inputs: the two input tensors ``[a, b]``.
        args_dict: test arguments; must contain "shift" and is passed to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused; accepted so all build_* methods share a signature.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    # Must come after the INT32 override so the wrong type is what remains
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        outputDType = self.rng.choice(all_dtypes)
        result_tensor.setDtype(outputDType)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        # Validators returned a falsy value: do not build this test
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700608
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add ``op`` applied to ``a`` with a TableAttribute built from ``table``.

    Returns the result tensor, or None when evValidateErrorIfs returns a
    falsy value.
    """
    resultTensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    tableAttr = ts.TosaSerializerAttribute()
    tableAttr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [resultTensor.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=resultTensor.dtype,
        result_tensors=[resultTensor],
        input_list=inputNames,
        output_list=outputNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inputNames, outputNames, tableAttr)
    return resultTensor
642
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a three-input select-style operator taking ``[cond, a, b]``.

    Returns TosaTestGen.BuildInfo with the result tensor and its compliance
    meta data (computed from ``a``'s dtype), or None when evValidateErrorIfs
    returns a falsy value. ``qinfo`` is unused; it is accepted so all
    build_* methods share a signature.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    resultTensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    inputNames, outputNames = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [resultTensor.name]
    )

    isValid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=resultTensor.dtype,
        result_tensors=[resultTensor],
        input_list=inputNames,
        output_list=outputNames,
        num_operands=pCount + cCount,
    )
    if not isValid:
        return None

    self.ser.addOperator(op["op"], inputNames, outputNames)
    complianceInfo = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, resultTensor, error_name
    )
    return TosaTestGen.BuildInfo(resultTensor, complianceInfo)
Eric Kunzee5e26762020-10-13 16:11:07 -0700690
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input comparison operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly two input tensors.
        args_dict: forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # For ERROR_IF tests the helper may invalidate these name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    # Compliance is keyed on the dtype of the first input.
    return TosaTestGen.BuildInfo(
        out_tens,
        self.tensorComplianceMetaData(op, lhs.dtype, args_dict, out_tens, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700738
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: sequence containing exactly one input tensor.
        args_dict: must provide "axis"; also forwarded to the compliance
            metadata helper.
        validator_fcns: ERROR_IF validator functions. Defaults to None, like
            the other build_* methods, so positional callers are unaffected.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700782
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: sequence containing exactly one input tensor.
        args_dict: must provide "stride", "pad" and "kernel"; "acc_type" is
            optional because max_pool has no accumulator type. Also
            forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list passed to PoolAttribute. It is
            treated as [0, 0] when None.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    # Named ifm rather than `input` to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100856
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly three tensors - input feature map, weights, bias.
        args_dict: must provide "acc_type", "stride", "pad" (4 values) and
            "dilation"; also forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            ConvAttribute. It is treated as [0, 0] when None instead of
            raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named `weights` rather than `filter` to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700938
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly three tensors - input feature map, weights, bias.
        args_dict: must provide "acc_type", "stride", "pad" (6 values) and
            "dilation"; also forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            ConvAttribute. It is treated as [0, 0] when None instead of
            raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named `weights` rather than `filter` to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001020
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly three tensors - input feature map, weights, bias.
        args_dict: must provide "acc_type", "stride", "pad" (4 values) and
            "out_shape"; also forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            TransposeConvAttribute. It is treated as [0, 0] when None instead
            of raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named `weights` rather than `filter` to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001095
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly three tensors - input feature map, weights, bias.
        args_dict: must provide "acc_type", "stride", "pad" and "dilation";
            also forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            ConvAttribute. It is treated as [0, 0] when None instead of
            raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named `weights` rather than `filter` to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001176
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly three tensors - input, weights, bias.
        args_dict: must provide "acc_type"; also forwarded to the compliance
            metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            FullyConnectedAttribute. It is treated as [0, 0] when None
            instead of raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named `weights` rather than `filter` to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001232
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: exactly two input tensors (a, b).
        args_dict: must provide "acc_type"; also forwarded to the compliance
            metadata helper.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two-element zero-point list whose entries are passed to
            MatMulAttribute. It is treated as [0, 0] when None instead of
            raising TypeError on indexing.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default to zero zero-points when no quantization info was supplied,
    # as build_pool2d does; the validators above still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001282
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        inputs: sequence containing exactly one input tensor.
        args_dict: must provide "axis"; also forwarded to the compliance
            metadata helper. For valid REDUCE_PRODUCT tests it is updated in
            place with "n", the size of the reduced axis.
        validator_fcns: ERROR_IF validator functions. Defaults to None, like
            the other build_* methods, so positional callers are unaffected.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance.
        # NOTE(review): this mutates the caller's args_dict in place; keep it
        # unless it is confirmed that no caller reads "n" afterwards.
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001331
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CLAMP operator whose bounds are drawn at random.

    Two random values of the input dtype are generated. Normally the smaller
    one becomes ``min_val`` and the larger one ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` test the pair is redrawn until the values differ
    and then assigned the other way round, so that ``max_val < min_val``.

    For BF16/FP16/FP32 inputs the bounds go into the last two ClampAttribute
    arguments (the first two are zero); for other dtypes they go into the
    first two (the last two are zero). FP16 bounds are converted to float32
    first.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    dtype = a.dtype

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _random_pair():
        # First element is drawn before the second, matching the RNG order
        return [self.getRandNumberDType(dtype), self.getRandNumberDType(dtype)]

    pair = _random_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot trigger the max < min error, so redraw
        while pair[0] == pair[1]:
            pair = _random_pair()
        max_val, min_val = min(pair), max(pair)
    else:
        max_val, min_val = max(pair), min(pair)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if dtype not in (DType.BF16, DType.FP16, DType.FP32):
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001397
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator whose attribute value is a random FP32 number.

    ``validator_fcns`` is accepted but not used by this builder.
    Returns the output tensor created by OutputShaper.unaryOp.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    rand_fp32 = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(rand_fp32)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
1406
# TODO: PRELU needs an additional type/input; only the single input `a` is wired up.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with ``a`` as its only input and no attribute.

    ``validator_fcns`` is accepted but not used by this builder.
    Returns the output tensor created by OutputShaper.unaryOp.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1413
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input activation operator that takes no attribute.

    The output tensor comes from OutputShaper.unaryOp. Returns
    TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001454
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT or CONCAT_SHAPE operator joining every tensor in ``inputs``.

    CONCAT_SHAPE always uses axis 0 and is serialized without an attribute.
    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it as an
    AxisAttribute.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), with compliance
    computed for the dtype of ``inputs[0]``, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance is the idiomatic check; bool (an int subclass) is still
        # rejected, as the previous exact-type comparison did.
        assert isinstance(axis, int) and not isinstance(
            axis, bool
        ), f"concat axis must be an int, got {type(axis).__name__}"

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001513
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator whose padding is supplied as a second input tensor.

    ``inputs`` is ``[a, pad_input]``. ``args_dict["pad"]`` is passed to
    OutputShaper.padOp and to ERROR_IF validation. The PadAttribute itself is
    written with an empty padding list so that ``inputs[1]`` is used for the
    padding. The constant pad values come from ``args_dict["pad_const_int"]``
    and ``args_dict["pad_const_fp"]``.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, pad_input = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # Empty padding in the attribute ensures inputs[1] is used instead
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, pad_input.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001571
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator for the single input, with an AxisAttribute.

    The axis is taken from ``args_dict["axis"]``. Returns
    TosaTestGen.BuildInfo(result_tensor, None), since no compliance metadata
    is produced here, or None when TosaErrorValidator.evValidateErrorIfs
    returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001617
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a RESHAPE operator taking the data tensor and a shape tensor.

    ``inputs`` is ``[a, shape_tensor]``. ``args_dict["new_shape"]`` is
    passed to OutputShaper.reshapeOp to create the output tensor. The
    serialized operator receives both input tensors and no attribute.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]

    result_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, new_shape, error_name
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, shape_tensor.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001661
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator with ``args_dict["axis"]`` as an AxisAttribute.

    The output tensor comes from OutputShaper.unaryOp. Returns
    TosaTestGen.BuildInfo(result_tensor, None), since no compliance metadata
    is produced here, or None when TosaErrorValidator.evValidateErrorIfs
    returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001701
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TRANSPOSE operator using ``args_dict["perms"]``.

    The permutation is passed to OutputShaper.transposeOp, to ERROR_IF
    validation and to the TransposeAttribute. Returns
    TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    perms = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, a, perms, error_name
    )

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001750
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SLICE operator taking the data tensor plus start/size tensors.

    ``inputs`` is ``[a, start_var, size_var]``. The constant
    ``args_dict["start"]`` and ``args_dict["size"]`` values are passed to
    OutputShaper.sliceOp, to ERROR_IF validation and to a SliceAttribute.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    a, start_var, size_var = inputs
    start_const = args_dict["start"]
    size_const = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start_const, size_const, error_name
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [a.name, start_var.name, size_var.name],
        [result_tensor.name],
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        start=start_const,
        size=size_const,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    # TODO remove the slice attribute once shape dynamism support is mature.
    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start_const, size_const)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001802
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TILE operator taking the data tensor and a multiples tensor.

    ``inputs`` is ``[a, multiples_tensor]``. ``args_dict["multiples"]`` is
    passed to OutputShaper.tileOp to create the output tensor. The serialized
    operator receives both input tensors and no attribute.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, multiples_tensor = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]

    result_tensor = OutputShaper.tileOp(
        self.ser, self.rng, a, multiples_values, error_name
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, multiples_tensor.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001847
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a GATHER operator from ``inputs == [values, indices]``.

    ERROR_IF validation and compliance metadata use the shape/dtype of
    ``values``. Returns TosaTestGen.BuildInfo(result_tensor, compliance), or
    None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    values, indices = inputs

    result_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=values.shape,
        output_shape=result_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001890
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SCATTER operator from ``inputs == [values_in, indices, input_tensor]``.

    ERROR_IF validation and compliance metadata use the shape/dtype of
    ``values_in``. Returns TosaTestGen.BuildInfo(result_tensor, compliance),
    or None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named input_tensor rather than `input` to avoid shadowing the builtin
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001932
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator for ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded unchanged to
    OutputShaper.resizeOp, to ERROR_IF validation and to ResizeAttribute.
    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    # NOTE: the parameter name `input` shadows the builtin; it is kept because
    # renaming it could break callers that pass it by keyword.
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[out_tensor],
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return out_tensor
1994
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add an IDENTITYN operator mapping ``val``/``val2`` to two new outputs.

    ``validator_fcns`` is accepted for signature compatibility with the other
    build_* methods; no ERROR_IF validation is performed here.

    Args:
        op: The op-list entry (a dict whose "op" key holds the operator), or,
            for older callers, the operator value itself.
        val, val2: The two input tensors.

    Returns:
        The first output tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The other builders serialize op["op"] from the op-list entry; passing the
    # whole dict to addOperator would serialize the wrong object.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
2002
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Mark the single supplied constant tensor as a graph output.

    Returns:
        A BuildInfo holding that tensor and its compliance metadata.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]
    self.ser.addOutputTensor(const_tensor)

    meta = self.tensorComplianceMetaData(
        op, const_tensor.dtype, args_dict, const_tensor, error_name
    )
    return TosaTestGen.BuildInfo(const_tensor, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002015
2016 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CAST of the single input tensor to ``args_dict["out_type"]``.

    Returns:
        A BuildInfo with the converted tensor and its compliance metadata, or
        None if an ERROR_IF validator rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [converted.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=converted.shape,
        input_dtype=src.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    meta = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, converted, error_name
    )
    return TosaTestGen.BuildInfo(converted, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002060
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a RESCALE operator converting ``val`` to ``out_dtype``.

    Picks input/output zero points: random for INT8/UINT8, 0 or 32768 for
    UINT16, deliberately non-zero for the zero-point ERROR_IF tests, and 0
    otherwise. Draws a random scale per channel (or a single scale), converts
    it to multiplier/shift pairs and, for valid scale32 tests, clamps the saved
    input data so it meets the apply_scale_32 precondition.

    Returns:
        The output tensor, or None if an ERROR_IF validator rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension) or a single pair.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            # The ERROR_IF test needs a non-zero zero point.
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            # The ERROR_IF test needs a non-zero zero point.
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift on a Python int: shift_arr[i] is an int32 NumPy scalar, and
        # under NumPy 2 promotion rules a shift of 32 or more would overflow.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)

        # Check we can safely convert back to the stored dtype. Compare the
        # real extremes; val_adj.all() is a bool, so a check built on it
        # could never fail.
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        )

        # Cast back to the dtype the values were stored with
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2233
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    Normally a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool test gets a wrong element type, and the
    CondIfCondShapeNotSizeOne test gets a shape of [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2250
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output one INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. For the
    output-list mismatch ERROR_IF tests, one block's constant is given a
    deliberately different shape.

    Returns:
        The result tensor, or None if an ERROR_IF validator rejects the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor takes the (correct) shape shared by both outputs.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Emit each block with a single constant as its only output.
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, mismatch_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if checks_pass else None
2319
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a``, ``b``.

    FP32/BF16/FP16/INT32 inputs use ADD (then) and SUB (else). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else). The
    input/output list mismatch ERROR_IF tests give one block a tensor with a
    perturbed shape.

    Returns:
        The result tensor, or None if an ERROR_IF validator rejects the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op pair.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the failure is not stripped under ``python -O``.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop variable: rebinding ``op`` here would replace the
    # op-list entry that is passed to the ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2402
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that repeatedly adds ``a`` into a zero accumulator.

    The COND block computes ``iter > 0``. The BODY block computes
    ``acc + a`` and ``iter - 1``, starting from ``iter = iter_val``. The
    input/output list mismatch ERROR_IF tests give the relevant block
    tensors with perturbed shapes.

    Returns:
        The accumulator output tensor, or None if an ERROR_IF validator
        rejects the test.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    def perturbed_copy(tens):
        # Deep copy of ``tens`` with every dimension moved by -3, -2, +2 or +3.
        bad = deepcopy(tens)
        for dim_idx in range(len(bad.shape)):
            bad.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        return bad

    # Accumulator tensor, initialised with zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = perturbed_copy(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 iterator has no dimension to perturb, so add one.
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = perturbed_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    cond_inputs = (
        (incorrect_iter, a, incorrect_acc)
        if error_name == ErrorIf.InputListCondGraphMismatch
        else (iter_tens, a, acc)
    )
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    body_inputs = (
        (incorrect_iter, a, incorrect_acc)
        if error_name == ErrorIf.InputListBodyGraphInputMismatch
        else (iter_tens, a, acc)
    )
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = incorrect_iter, incorrect_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if checks_pass else None
2523
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an FFT2D operator for the two input tensors.

    ``args_dict["inverse"]`` selects the inverse transform. ``local_bound`` is
    currently always serialized as False.

    Returns:
        A BuildInfo with the output tensors from OutputShaper.fft2dOp and one
        compliance entry per output, or None if an ERROR_IF validator rejects
        the test.
    """
    assert len(inputs) == 2
    in_first, in_second = inputs
    inverse = args_dict["inverse"]

    outputs = OutputShaper.fft2dOp(
        self.ser, self.rng, in_first, in_second, error_name
    )

    placeholder_count, const_count = op["operands"]
    out_shapes = [t.shape for t in outputs]
    out_dtypes = [t.dtype for t in outputs]

    # Invalidate Input/Output list for error if checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [in_first.name, in_second.name],
        [t.name for t in outputs],
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=in_first,
        input2=in_second,
        input_shape=in_first.shape,
        input_dtype=in_first.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=outputs,
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, False)
    self.ser.addOperator(op["op"], names_in, names_out, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, in_first.dtype, args_dict, t, error_name)
        for t in outputs
    ]
    return TosaTestGen.BuildInfo(outputs, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002587
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an RFFT2D operator for the single input tensor.

    ``local_bound`` is currently always serialized as False.

    Returns:
        A BuildInfo with the output tensors from OutputShaper.rfft2dOp and one
        compliance entry per output, or None if an ERROR_IF validator rejects
        the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    outputs = OutputShaper.rfft2dOp(self.ser, self.rng, src, error_name)

    placeholder_count, const_count = op["operands"]
    out_shapes = [t.shape for t in outputs]
    out_dtypes = [t.dtype for t in outputs]

    # Invalidate Input/Output list for error if checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [t.name for t in outputs]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=outputs,
        input_list=names_in,
        output_list=names_out,
        num_operands=placeholder_count + const_count,
    )
    if not checks_pass:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)
    self.ser.addOperator(op["op"], names_in, names_out, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, src.dtype, args_dict, t, error_name)
        for t in outputs
    ]
    return TosaTestGen.BuildInfo(outputs, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002644
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test.

    The output tensor is created with OutputShaper.addShapeOp for every
    operator routed through this build function.

    Args:
        op: Operator entry from TOSA_OP_LIST.
        inputs: List holding exactly two input tensors.
        args_dict: Test argument dictionary, passed on to the compliance
            meta data.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: Name of the ERROR_IF being tested, or None.
        qinfo: Not used by this build function; accepted so every build
            function shares the same signature.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance meta
        data, or None when the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)

    pCount, cCount = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        ),
    )
2690
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested test filters with what the operator allows.

    Args:
        op: Operator entry from TOSA_OP_LIST. ``op["rank"]`` is an inclusive
            ``(min, max)`` pair and ``op["types"]`` the list of dtypes to test.
            A dtype entry may itself be a list, in which case it is matched on
            its first element.
        shapeFilter: Requested shapes; a falsy value means "no restriction"
            and is treated as ``[None]``.
        rankFilter: Requested ranks, or None for the default rank selection.
        dtypeFilter: Requested dtypes, or None for all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; called as
            ``validator(check=False, op=op)`` for negative tests, and its
            ``"param_reqs"`` override the rank/dtype/shape filters.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        when a negative test has no validator or ``testType`` is neither
        "positive" nor "negative".
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Clip the default ranks to the operator's range - on its own the
            # default range could contain ranks below rmin
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # Every operator rank lies above the default range; test the
                # lowest ranks the operator supports, keeping the same count
                cleanRankFilter = range(rmin, rmin + len(default_test_rank_range))
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]

    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes; a list
        # entry (e.g. a TYPE_CONV [input1, input2, accumulator] combination)
        # matches on its first type
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    elif testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF being tested
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }
    return None
2771
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Create the list of tests to generate for the operator ``opName``.

    ``self.rng`` is re-created from ``self.random_seed`` before any shapes or
    arguments are generated.

    Args:
        opName: Key of the operator in TOSA_OP_LIST.
        shapeFilter: Shapes to restrict the tests to. None (the default) or
            ``[None]`` means no shape restriction.
        rankFilter: Ranks to restrict the tests to, or None.
        dtypeFilter: Data types to restrict the tests to, or None.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        List of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. An empty list is returned when create_filter_lists returns
        None (e.g. a negative test request for an op without ERROR_IF
        validators).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Replaces the former shared mutable ``[None]`` default; same meaning
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (argStr, args) tuples: the
                    # string representation and the build function arguments
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                # Every validator is evaluated, as before
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList

    return testList
2878
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for ``opName`` and serialize it when it is valid.

    Args:
        opName: Key of the operator in TOSA_OP_LIST.
        testStr: Name of the test, used for the serializer and messages.
        dtype_or_dtypeList: One dtype for every operand, or a per-operand list.
        error_name: Name of the ERROR_IF being tested, or None.
        shapeList: Shapes of the operands.
        testArgs: Either an argument dictionary (new interface, must contain
            "dg_type") or a list of positional build arguments (old interface).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # CONCAT ops take a variable number of inputs, one per shape
    variable_inputs = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if variable_inputs else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not variable_inputs:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info, when the op defines a generator for it
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: arguments held in a dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: tensors and arguments are passed positionally
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        # Without validators qinfo is passed positionally; with validators
        # everything extra is passed by keyword
        trailing_args = ()
        keyword_args = {}
        if error_if_validators is None:
            if qinfo is not None:
                trailing_args = (qinfo,)
        else:
            keyword_args = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                keyword_args["qinfo"] = qinfo

        try:
            result = build_fcn(
                self, op, *tens, *testArgs, *trailing_args, **keyword_args
            )
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it with any compliance meta data
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003002
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST into concrete ops.

    Each of conv2d/depthwise_conv2d/transpose_conv2d_TEMPLATE is copied once
    per 2D kernel size, and conv3d_TEMPLATE once per 3D kernel size. Copies
    are named "<op>_<k0>x<k1>[x<k2>]", get "filter" set to the kernel and
    "template" set to False. Afterwards every entry still flagged as a
    template is removed. The kernel sizes used depend on
    ``self.args.level8k``.

    Does nothing if "conv2d_TEMPLATE" is missing, i.e. the lists were
    already created.
    """
    if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # (kernel sizes, template prefixes); the order sets dictionary insertion order
    expansions = (
        (kernels_2d, ("conv2d", "depthwise_conv2d", "transpose_conv2d")),
        (kernels_3d, ("conv3d",)),
    )
    for kernels, prefixes in expansions:
        for kernel in kernels:
            suffix = "x".join(str(size) for size in kernel)
            for prefix in prefixes:
                entry = self.TOSA_OP_LIST[f"{prefix}_TEMPLATE"].copy()
                entry["filter"] = kernel
                entry["template"] = False
                self.TOSA_OP_LIST[f"{prefix}_{suffix}"] = entry

    # Collect the templates first, then delete them - a dict must not change
    # size while it is being iterated
    templates = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]
    for name in templates:
        del self.TOSA_OP_LIST[name]
3058
def initOpListDefaults(self):
    """Validate every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide a 2-item "operands" tuple, a 4-item "build_fcn"
    tuple, a "types" list and an "op" field; otherwise an Exception naming
    the op is raised. A missing "rank" is set to DEFAULT_RANK_RANGE.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields - unpacking checks the tuple length as well
        try:
            _, _ = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _, _, _, _ = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        # Fields that only need to be present
        for field, description in (
            ("types", "a valid type list"),
            ("op", "the Op field"),
        ):
            try:
                entry[field]
            except KeyError:
                raise Exception(
                    "Op {} is missing {} in TOSA_OP_LIST".format(name, description)
                )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003100
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, defaults to DEFAULT_RANK_RANGE, i.e.
#     (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator function
# 'invalid_test_validators': optional, functions that flag positive tests to
#     drop because they are invalid without matching an ERROR_IF
# 'error_if_validators': optional, ERROR_IF validators used to create
#     negative tests
# 'data_gen': optional, data generator types keyed by datatype group
#     (e.g. "fp")
# 'template': optional, True for entries that createDynamicOpLists expands
#     into concrete ops and then removes
# 'filter': kernel size, filled in by createDynamicOpLists for templated ops
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
# (entries must stay lists: create_filter_lists matches list dtypes on item 0)
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range applied by initOpListDefaults to ops without a 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003152
3153 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003154 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003155 "argmax": {
3156 "op": Op.ARGMAX,
3157 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003158 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003159 "build_fcn": (
3160 build_argmax,
3161 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003162 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163 TosaArgGen.agAxis,
3164 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003165 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003166 "error_if_validators": (
3167 TosaErrorValidator.evAxisSmallerZero,
3168 TosaErrorValidator.evAxisLargerRank,
3169 TosaErrorValidator.evArgmaxOutputRankMismatch,
3170 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3171 TosaErrorValidator.evWrongRank,
3172 TosaErrorValidator.evWrongInputType,
3173 TosaErrorValidator.evWrongOutputType,
3174 TosaErrorValidator.evWrongInputList,
3175 TosaErrorValidator.evWrongOutputList,
3176 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003177 "data_gen": {
3178 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3179 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003181 "avg_pool2d": {
3182 "op": Op.AVG_POOL2D,
3183 "operands": (1, 0),
3184 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003185 "build_fcn": (
3186 build_pool2d,
3187 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003188 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003189 TosaArgGen.agPooling,
3190 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003191 "qgen": TosaQuantGen.qgUnary,
3192 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003193 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003194 "error_if_validators": (
3195 TosaErrorValidator.evKernelSmallerOne,
3196 TosaErrorValidator.evStrideSmallerOne,
3197 TosaErrorValidator.evPadSmallerZero,
3198 TosaErrorValidator.evWrongRank,
3199 TosaErrorValidator.evWrongInputType,
3200 TosaErrorValidator.evWrongOutputType,
3201 TosaErrorValidator.evWrongInputList,
3202 TosaErrorValidator.evWrongOutputList,
3203 TosaErrorValidator.evInputZeroPointNotZero,
3204 TosaErrorValidator.evOutputZeroPointNotZero,
3205 TosaErrorValidator.evPadLargerEqualKernel,
3206 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003207 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003209 "data_gen": {
3210 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003213 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003214 "conv2d_TEMPLATE": {
3215 "op": Op.CONV2D,
3216 "operands": (1, 2),
3217 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 "build_fcn": (
3219 build_conv2d,
3220 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003221 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 TosaArgGen.agConv,
3223 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003224 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003225 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003226 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3227 "error_if_validators": (
3228 TosaErrorValidator.evWrongInputType,
3229 TosaErrorValidator.evWrongOutputType,
3230 TosaErrorValidator.evWrongInputList,
3231 TosaErrorValidator.evWrongOutputList,
3232 TosaErrorValidator.evInputZeroPointNotZero,
3233 TosaErrorValidator.evWeightZeroPointNotZero,
3234 TosaErrorValidator.evPadSmallerZero,
3235 TosaErrorValidator.evStrideSmallerOne,
3236 TosaErrorValidator.evDilationSmallerOne,
3237 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003238 TosaErrorValidator.evConvOutputShapeMismatch,
3239 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003240 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003241 "data_gen": {
3242 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3243 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003244 "template": True,
3245 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003246 # Templated operator. Filled in by createDynamicOpLists
3247 "conv3d_TEMPLATE": {
3248 "op": Op.CONV3D,
3249 "operands": (1, 2),
3250 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 "build_fcn": (
3252 build_conv3d,
3253 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003254 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 TosaArgGen.agConv,
3256 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003257 "qgen": TosaQuantGen.qgConv,
3258 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003259 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3260 "error_if_validators": (
3261 TosaErrorValidator.evWrongInputType,
3262 TosaErrorValidator.evWrongOutputType,
3263 TosaErrorValidator.evWrongInputList,
3264 TosaErrorValidator.evWrongOutputList,
3265 TosaErrorValidator.evInputZeroPointNotZero,
3266 TosaErrorValidator.evWeightZeroPointNotZero,
3267 TosaErrorValidator.evPadSmallerZero,
3268 TosaErrorValidator.evStrideSmallerOne,
3269 TosaErrorValidator.evDilationSmallerOne,
3270 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003271 TosaErrorValidator.evConvOutputShapeMismatch,
3272 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003273 ),
evacha0147ab1762024-01-29 13:23:23 +00003274 "data_gen": {
3275 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3276 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003277 "template": True,
3278 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003279 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003280 "depthwise_conv2d_TEMPLATE": {
3281 "op": Op.DEPTHWISE_CONV2D,
3282 "operands": (1, 2),
3283 "filter": [1, 1],
3284 "rank": (4, 4),
3285 "build_fcn": (
3286 build_depthwise_conv2d,
3287 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003288 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003289 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003290 ),
3291 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003292 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003293 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3294 "error_if_validators": (
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 TosaErrorValidator.evInputZeroPointNotZero,
3300 TosaErrorValidator.evWeightZeroPointNotZero,
3301 TosaErrorValidator.evPadSmallerZero,
3302 TosaErrorValidator.evStrideSmallerOne,
3303 TosaErrorValidator.evDilationSmallerOne,
3304 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003305 TosaErrorValidator.evConvOutputShapeMismatch,
3306 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003307 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003308 "data_gen": {
3309 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3310 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003311 "template": True,
3312 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "fully_connected": {
3314 "op": Op.FULLY_CONNECTED,
3315 "operands": (1, 2),
3316 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003317 "build_fcn": (
3318 build_fully_connected,
3319 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003320 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003321 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003322 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003323 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003324 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003325 "error_if_validators": (
3326 TosaErrorValidator.evInputZeroPointNotZero,
3327 TosaErrorValidator.evWeightZeroPointNotZero,
3328 TosaErrorValidator.evWrongRank,
3329 TosaErrorValidator.evWrongInputType,
3330 TosaErrorValidator.evWrongOutputType,
3331 TosaErrorValidator.evWrongInputList,
3332 TosaErrorValidator.evWrongOutputList,
3333 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003334 "data_gen": {
3335 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "matmul": {
3339 "op": Op.MATMUL,
3340 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003341 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_matmul,
3344 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003345 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003346 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 "qgen": TosaQuantGen.qgMatmul,
3349 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003350 "error_if_validators": (
3351 TosaErrorValidator.evInputZeroPointNotZero,
3352 TosaErrorValidator.evWrongRank,
3353 TosaErrorValidator.evWrongInputType,
3354 TosaErrorValidator.evWrongOutputType,
3355 TosaErrorValidator.evWrongInputList,
3356 TosaErrorValidator.evWrongOutputList,
3357 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003358 "data_gen": {
3359 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003362 "max_pool2d": {
3363 "op": Op.MAX_POOL2D,
3364 "operands": (1, 0),
3365 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003366 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003367 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003369 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003370 TosaArgGen.agPooling,
3371 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003373 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003374 "error_if_validators": (
3375 TosaErrorValidator.evKernelSmallerOne,
3376 TosaErrorValidator.evStrideSmallerOne,
3377 TosaErrorValidator.evPadSmallerZero,
3378 TosaErrorValidator.evWrongRank,
3379 TosaErrorValidator.evWrongInputType,
3380 TosaErrorValidator.evWrongOutputType,
3381 TosaErrorValidator.evWrongInputList,
3382 TosaErrorValidator.evWrongOutputList,
3383 TosaErrorValidator.evPadLargerEqualKernel,
3384 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003385 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003386 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003387 "data_gen": {
3388 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3389 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003390 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003391 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003392 "transpose_conv2d_TEMPLATE": {
3393 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003394 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003395 "rank": (4, 4),
3396 "build_fcn": (
3397 build_transpose_conv2d,
3398 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003399 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003400 TosaArgGen.agTransposeConv2D,
3401 ),
3402 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003403 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003404 "invalid_test_validators": (
3405 TosaInvalidValidator.ivHeightWidthInvalid,
3406 TosaInvalidValidator.ivNonPositiveOutputShape,
3407 ),
3408 "error_if_validators": (
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 TosaErrorValidator.evInputZeroPointNotZero,
3414 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003415 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003416 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003417 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003418 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003419 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003420 "data_gen": {
3421 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3422 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003423 "template": True,
3424 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003425 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003426 "clamp": {
3427 "op": Op.CLAMP,
3428 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 "build_fcn": (
3430 build_clamp,
3431 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003432 TosaTensorValuesGen.tvgLazyGenDefault,
3433 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003435 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 "error_if_validators": (
3437 TosaErrorValidator.evMaxSmallerMin,
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evWrongOutputType,
3440 TosaErrorValidator.evWrongInputList,
3441 TosaErrorValidator.evWrongOutputList,
3442 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003443 "data_gen": {
3444 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3445 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003446 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003447 "sigmoid": {
3448 "op": Op.SIGMOID,
3449 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003450 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003451 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003453 TosaTensorValuesGen.tvgLazyGenDefault,
3454 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003455 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003456 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 "error_if_validators": (
3458 TosaErrorValidator.evWrongInputType,
3459 TosaErrorValidator.evWrongOutputType,
3460 TosaErrorValidator.evWrongInputList,
3461 TosaErrorValidator.evWrongOutputList,
3462 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003463 "data_gen": {
3464 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3465 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 },
3467 "tanh": {
3468 "op": Op.TANH,
3469 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003470 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003471 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003473 TosaTensorValuesGen.tvgLazyGenDefault,
3474 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003475 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003476 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003477 "error_if_validators": (
3478 TosaErrorValidator.evWrongInputType,
3479 TosaErrorValidator.evWrongOutputType,
3480 TosaErrorValidator.evWrongInputList,
3481 TosaErrorValidator.evWrongOutputList,
3482 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003483 "data_gen": {
3484 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3485 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003486 "compliance": {
3487 "abs_error_lower_bound": 0.5,
3488 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003489 },
Won Jeon78155c62023-06-10 00:20:04 +00003490 "erf": {
3491 "op": Op.ERF,
3492 "operands": (1, 0),
3493 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003494 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003495 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003496 TosaTensorValuesGen.tvgLazyGenDefault,
3497 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003498 ),
3499 "types": TYPE_FP,
3500 "error_if_validators": (
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003506 "data_gen": {
3507 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3508 },
3509 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003510 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 # Elementwise Binary Operators
3512 "add": {
3513 "op": Op.ADD,
3514 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003515 "build_fcn": (
3516 build_binary_broadcast,
3517 TosaTensorGen.tgBroadcastFuzz,
3518 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003519 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003520 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 "error_if_validators": (
3523 TosaErrorValidator.evRankMismatch,
3524 TosaErrorValidator.evWrongInputType,
3525 TosaErrorValidator.evWrongOutputType,
3526 TosaErrorValidator.evWrongInputList,
3527 TosaErrorValidator.evWrongOutputList,
3528 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003529 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003531 "data_gen": {
3532 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3533 },
3534 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "arithmetic_right_shift": {
3537 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3538 "operands": (2, 0),
3539 "build_fcn": (
3540 build_arithmetic_right_shift,
3541 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 TosaArgGen.agArithmeticRightShift,
3544 ),
3545 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evRankMismatch,
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003553 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 "bitwise_and": {
3557 "op": Op.BITWISE_AND,
3558 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 "build_fcn": (
3560 build_binary_broadcast,
3561 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003562 TosaTensorValuesGen.tvgLazyGenDefault,
3563 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003566 "error_if_validators": (
3567 TosaErrorValidator.evRankMismatch,
3568 TosaErrorValidator.evWrongInputType,
3569 TosaErrorValidator.evWrongOutputType,
3570 TosaErrorValidator.evWrongInputList,
3571 TosaErrorValidator.evWrongOutputList,
3572 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003573 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003576 "bitwise_or": {
3577 "op": Op.BITWISE_OR,
3578 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 "build_fcn": (
3580 build_binary_broadcast,
3581 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003582 TosaTensorValuesGen.tvgLazyGenDefault,
3583 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003586 "error_if_validators": (
3587 TosaErrorValidator.evRankMismatch,
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongInputList,
3591 TosaErrorValidator.evWrongOutputList,
3592 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003593 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003596 "bitwise_xor": {
3597 "op": Op.BITWISE_XOR,
3598 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 "build_fcn": (
3600 build_binary_broadcast,
3601 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003602 TosaTensorValuesGen.tvgLazyGenDefault,
3603 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003604 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003606 "error_if_validators": (
3607 TosaErrorValidator.evRankMismatch,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongInputList,
3611 TosaErrorValidator.evWrongOutputList,
3612 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003613 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003616 "intdiv": {
3617 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003618 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 "build_fcn": (
3620 build_binary_broadcast,
3621 TosaTensorGen.tgBroadcastFuzz,
3622 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003623 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003624 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003625 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 "error_if_validators": (
3627 TosaErrorValidator.evRankMismatch,
3628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongInputList,
3631 TosaErrorValidator.evWrongOutputList,
3632 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003633 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003634 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "logical_and": {
3637 "op": Op.LOGICAL_AND,
3638 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_binary_broadcast,
3641 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003642 TosaTensorValuesGen.tvgLazyGenDefault,
3643 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evRankMismatch,
3648 TosaErrorValidator.evWrongInputType,
3649 TosaErrorValidator.evWrongOutputType,
3650 TosaErrorValidator.evWrongInputList,
3651 TosaErrorValidator.evWrongOutputList,
3652 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003653 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 "logical_left_shift": {
3657 "op": Op.LOGICAL_LEFT_SHIFT,
3658 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 "build_fcn": (
3660 build_binary_broadcast,
3661 TosaTensorGen.tgBroadcastFuzz,
3662 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003663 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003666 "error_if_validators": (
3667 TosaErrorValidator.evRankMismatch,
3668 TosaErrorValidator.evWrongInputType,
3669 TosaErrorValidator.evWrongOutputType,
3670 TosaErrorValidator.evWrongInputList,
3671 TosaErrorValidator.evWrongOutputList,
3672 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003673 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "logical_right_shift": {
3677 "op": Op.LOGICAL_RIGHT_SHIFT,
3678 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_binary_broadcast,
3681 TosaTensorGen.tgBroadcastFuzz,
3682 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003683 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003686 "error_if_validators": (
3687 TosaErrorValidator.evRankMismatch,
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003693 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "logical_or": {
3697 "op": Op.LOGICAL_OR,
3698 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 "build_fcn": (
3700 build_binary_broadcast,
3701 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003702 TosaTensorValuesGen.tvgLazyGenDefault,
3703 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003706 "error_if_validators": (
3707 TosaErrorValidator.evRankMismatch,
3708 TosaErrorValidator.evWrongInputType,
3709 TosaErrorValidator.evWrongOutputType,
3710 TosaErrorValidator.evWrongInputList,
3711 TosaErrorValidator.evWrongOutputList,
3712 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003713 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003715 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "logical_xor": {
3717 "op": Op.LOGICAL_XOR,
3718 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 "build_fcn": (
3720 build_binary_broadcast,
3721 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003722 TosaTensorValuesGen.tvgLazyGenDefault,
3723 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003726 "error_if_validators": (
3727 TosaErrorValidator.evRankMismatch,
3728 TosaErrorValidator.evWrongInputType,
3729 TosaErrorValidator.evWrongOutputType,
3730 TosaErrorValidator.evWrongInputList,
3731 TosaErrorValidator.evWrongOutputList,
3732 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003733 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003736 "maximum": {
3737 "op": Op.MAXIMUM,
3738 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003739 "build_fcn": (
3740 build_binary_broadcast,
3741 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003742 TosaTensorValuesGen.tvgLazyGenDefault,
3743 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003746 "error_if_validators": (
3747 TosaErrorValidator.evRankMismatch,
3748 TosaErrorValidator.evWrongInputType,
3749 TosaErrorValidator.evWrongOutputType,
3750 TosaErrorValidator.evWrongInputList,
3751 TosaErrorValidator.evWrongOutputList,
3752 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003753 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003755 "data_gen": {
3756 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3757 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "minimum": {
3760 "op": Op.MINIMUM,
3761 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762 "build_fcn": (
3763 build_binary_broadcast,
3764 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003765 TosaTensorValuesGen.tvgLazyGenDefault,
3766 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 "error_if_validators": (
3770 TosaErrorValidator.evRankMismatch,
3771 TosaErrorValidator.evWrongInputType,
3772 TosaErrorValidator.evWrongOutputType,
3773 TosaErrorValidator.evWrongInputList,
3774 TosaErrorValidator.evWrongOutputList,
3775 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003776 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003777 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003778 "data_gen": {
3779 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "mul": {
3783 "op": Op.MUL,
3784 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_mul,
3787 TosaTensorGen.tgBroadcastFuzz,
3788 TosaTensorValuesGen.tvgMul,
3789 TosaArgGen.agMul,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 TosaErrorValidator.evRankMismatch,
3798 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003799 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003800 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003801 "data_gen": {
3802 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3803 },
3804 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "pow": {
3807 "op": Op.POW,
3808 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 "build_fcn": (
3810 build_binary_broadcast,
3811 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003812 TosaTensorValuesGen.tvgPow,
3813 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evRankMismatch,
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003823 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003824 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003825 "data_gen": {
3826 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 "sub": {
3830 "op": Op.SUB,
3831 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 "build_fcn": (
3833 build_binary_broadcast,
3834 TosaTensorGen.tgBroadcastFuzz,
3835 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003836 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003838 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 "error_if_validators": (
3840 TosaErrorValidator.evRankMismatch,
3841 TosaErrorValidator.evWrongInputType,
3842 TosaErrorValidator.evWrongOutputType,
3843 TosaErrorValidator.evWrongInputList,
3844 TosaErrorValidator.evWrongOutputList,
3845 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003846 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 "data_gen": {
3849 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3850 },
3851 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "table": {
3854 "op": Op.TABLE,
3855 # Use the automatic generation functions to create the input array
3856 # but create the table tensor in the build function, as it may be
3857 # a different type from the input
3858 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 "build_fcn": (
3860 build_table,
3861 TosaTensorGen.tgBasic,
3862 TosaTensorValuesGen.tvgDefault,
3863 TosaArgGen.agTable,
3864 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003865 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003866 "error_if_validators": (
3867 TosaErrorValidator.evWrongInputType,
3868 TosaErrorValidator.evWrongOutputType,
3869 TosaErrorValidator.evWrongInputList,
3870 TosaErrorValidator.evWrongOutputList,
3871 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 # Elementwise Unary operators
3874 "abs": {
3875 "op": Op.ABS,
3876 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003877 "build_fcn": (
3878 build_unary,
3879 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 TosaTensorValuesGen.tvgLazyGenDefault,
3881 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 "error_if_validators": (
3885 TosaErrorValidator.evWrongInputType,
3886 TosaErrorValidator.evWrongOutputType,
3887 TosaErrorValidator.evWrongInputList,
3888 TosaErrorValidator.evWrongOutputList,
3889 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003890 "data_gen": {
3891 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3892 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003894 "bitwise_not": {
3895 "op": Op.BITWISE_NOT,
3896 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 "build_fcn": (
3898 build_unary,
3899 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003900 TosaTensorValuesGen.tvgLazyGenDefault,
3901 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 "error_if_validators": (
3905 TosaErrorValidator.evWrongInputType,
3906 TosaErrorValidator.evWrongOutputType,
3907 TosaErrorValidator.evWrongInputList,
3908 TosaErrorValidator.evWrongOutputList,
3909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 "ceil": {
3912 "op": Op.CEIL,
3913 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 "build_fcn": (
3915 build_unary,
3916 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003917 TosaTensorValuesGen.tvgLazyGenDefault,
3918 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 "error_if_validators": (
3922 TosaErrorValidator.evWrongInputType,
3923 TosaErrorValidator.evWrongOutputType,
3924 TosaErrorValidator.evWrongInputList,
3925 TosaErrorValidator.evWrongOutputList,
3926 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003927 "data_gen": {
3928 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3929 },
3930 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003932 "clz": {
3933 "op": Op.CLZ,
3934 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003935 "build_fcn": (
3936 build_unary,
3937 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003938 TosaTensorValuesGen.tvgLazyGenDefault,
3939 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003940 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003942 "error_if_validators": (
3943 TosaErrorValidator.evWrongInputType,
3944 TosaErrorValidator.evWrongOutputType,
3945 TosaErrorValidator.evWrongInputList,
3946 TosaErrorValidator.evWrongOutputList,
3947 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003948 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003949 "exp": {
3950 "op": Op.EXP,
3951 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 "build_fcn": (
3953 build_unary,
3954 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003955 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003956 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003957 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 "error_if_validators": (
3960 TosaErrorValidator.evWrongInputType,
3961 TosaErrorValidator.evWrongOutputType,
3962 TosaErrorValidator.evWrongInputList,
3963 TosaErrorValidator.evWrongOutputList,
3964 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003965 "data_gen": {
3966 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 "floor": {
3970 "op": Op.FLOOR,
3971 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 "build_fcn": (
3973 build_unary,
3974 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003975 TosaTensorValuesGen.tvgLazyGenDefault,
3976 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003977 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003978 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003979 "error_if_validators": (
3980 TosaErrorValidator.evWrongInputType,
3981 TosaErrorValidator.evWrongOutputType,
3982 TosaErrorValidator.evWrongInputList,
3983 TosaErrorValidator.evWrongOutputList,
3984 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003985 "data_gen": {
3986 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3987 },
3988 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 "log": {
3991 "op": Op.LOG,
3992 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 "build_fcn": (
3994 build_unary,
3995 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003996 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003997 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004000 "error_if_validators": (
4001 TosaErrorValidator.evWrongInputType,
4002 TosaErrorValidator.evWrongOutputType,
4003 TosaErrorValidator.evWrongInputList,
4004 TosaErrorValidator.evWrongOutputList,
4005 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004006 "data_gen": {
4007 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4008 },
4009 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 "logical_not": {
4012 "op": Op.LOGICAL_NOT,
4013 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 "build_fcn": (
4015 build_unary,
4016 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004017 TosaTensorValuesGen.tvgLazyGenDefault,
4018 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 "error_if_validators": (
4022 TosaErrorValidator.evWrongInputType,
4023 TosaErrorValidator.evWrongOutputType,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004027 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004028 "negate": {
4029 "op": Op.NEGATE,
4030 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004031 "build_fcn": (
4032 build_unary,
4033 TosaTensorGen.tgBasic,
4034 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004035 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 "qgen": TosaQuantGen.qgUnary,
4038 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004039 "error_if_validators": (
4040 TosaErrorValidator.evInputZeroPointNotZero,
4041 TosaErrorValidator.evOutputZeroPointNotZero,
4042 TosaErrorValidator.evWrongInputType,
4043 TosaErrorValidator.evWrongOutputType,
4044 TosaErrorValidator.evWrongInputList,
4045 TosaErrorValidator.evWrongOutputList,
4046 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004047 "data_gen": {
4048 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4049 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004050 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 "reciprocal": {
4052 "op": Op.RECIPROCAL,
4053 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004054 "build_fcn": (
4055 build_unary,
4056 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004057 TosaTensorValuesGen.tvgLazyGenDefault,
4058 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004059 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004061 "error_if_validators": (
4062 TosaErrorValidator.evWrongInputType,
4063 TosaErrorValidator.evWrongOutputType,
4064 TosaErrorValidator.evWrongInputList,
4065 TosaErrorValidator.evWrongOutputList,
4066 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004067 "data_gen": {
4068 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4069 },
4070 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004071 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "rsqrt": {
4073 "op": Op.RSQRT,
4074 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004075 "build_fcn": (
4076 build_unary,
4077 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004078 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004079 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004080 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004082 "error_if_validators": (
4083 TosaErrorValidator.evWrongInputType,
4084 TosaErrorValidator.evWrongOutputType,
4085 TosaErrorValidator.evWrongInputList,
4086 TosaErrorValidator.evWrongOutputList,
4087 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004088 "data_gen": {
4089 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4090 },
4091 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 # Elementwise Ternary operators
4094 "select": {
4095 "op": Op.SELECT,
4096 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004097 "build_fcn": (
4098 build_select,
4099 TosaTensorGen.tgBroadcastFuzz,
4100 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004101 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004104 "error_if_validators": (
4105 TosaErrorValidator.evRankMismatch,
4106 TosaErrorValidator.evWrongInputType,
4107 TosaErrorValidator.evWrongOutputType,
4108 TosaErrorValidator.evWrongInputList,
4109 TosaErrorValidator.evWrongOutputList,
4110 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004111 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004112 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004113 "data_gen": {
4114 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4115 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 # Comparison operators
4118 "equal": {
4119 "op": Op.EQUAL,
4120 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004121 "build_fcn": (
4122 build_comparison,
4123 TosaTensorGen.tgBroadcastFuzz,
4124 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004125 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004126 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 "error_if_validators": (
4129 TosaErrorValidator.evRankMismatch,
4130 TosaErrorValidator.evWrongInputType,
4131 TosaErrorValidator.evWrongOutputType,
4132 TosaErrorValidator.evWrongInputList,
4133 TosaErrorValidator.evWrongOutputList,
4134 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004135 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004136 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004137 "data_gen": {
4138 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4139 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004141 "greater_equal": {
4142 "op": Op.GREATER_EQUAL,
4143 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004144 "build_fcn": (
4145 build_comparison,
4146 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004147 TosaTensorValuesGen.tvgLazyGenDefault,
4148 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 "error_if_validators": (
4152 TosaErrorValidator.evRankMismatch,
4153 TosaErrorValidator.evWrongInputType,
4154 TosaErrorValidator.evWrongOutputType,
4155 TosaErrorValidator.evWrongInputList,
4156 TosaErrorValidator.evWrongOutputList,
4157 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004158 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004159 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004160 "data_gen": {
4161 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 "greater": {
4165 "op": Op.GREATER,
4166 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004167 "build_fcn": (
4168 build_comparison,
4169 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004170 TosaTensorValuesGen.tvgLazyGenDefault,
4171 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004172 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004173 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004174 "error_if_validators": (
4175 TosaErrorValidator.evRankMismatch,
4176 TosaErrorValidator.evWrongInputType,
4177 TosaErrorValidator.evWrongOutputType,
4178 TosaErrorValidator.evWrongInputList,
4179 TosaErrorValidator.evWrongOutputList,
4180 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004181 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004182 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004183 "data_gen": {
4184 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004187 # Reduction operators
4188 "reduce_all": {
4189 "op": Op.REDUCE_ALL,
4190 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004191 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004192 "build_fcn": (
4193 build_reduce,
4194 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004195 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004196 TosaArgGen.agAxis,
4197 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004199 "error_if_validators": (
4200 TosaErrorValidator.evAxisLargerRank,
4201 TosaErrorValidator.evAxisSmallerZero,
4202 TosaErrorValidator.evShapeOfAxisNotOne,
4203 TosaErrorValidator.evWrongInputType,
4204 TosaErrorValidator.evWrongOutputType,
4205 TosaErrorValidator.evWrongRank,
4206 TosaErrorValidator.evWrongInputList,
4207 TosaErrorValidator.evWrongOutputList,
4208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 "reduce_any": {
4211 "op": Op.REDUCE_ANY,
4212 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004213 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004214 "build_fcn": (
4215 build_reduce,
4216 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004217 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004218 TosaArgGen.agAxis,
4219 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004221 "error_if_validators": (
4222 TosaErrorValidator.evAxisLargerRank,
4223 TosaErrorValidator.evAxisSmallerZero,
4224 TosaErrorValidator.evShapeOfAxisNotOne,
4225 TosaErrorValidator.evWrongInputType,
4226 TosaErrorValidator.evWrongOutputType,
4227 TosaErrorValidator.evWrongRank,
4228 TosaErrorValidator.evWrongInputList,
4229 TosaErrorValidator.evWrongOutputList,
4230 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004232 "reduce_max": {
4233 "op": Op.REDUCE_MAX,
4234 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004235 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004236 "build_fcn": (
4237 build_reduce,
4238 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004239 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004240 TosaArgGen.agAxis,
4241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004242 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004243 "error_if_validators": (
4244 TosaErrorValidator.evAxisLargerRank,
4245 TosaErrorValidator.evAxisSmallerZero,
4246 TosaErrorValidator.evShapeOfAxisNotOne,
4247 TosaErrorValidator.evWrongInputType,
4248 TosaErrorValidator.evWrongOutputType,
4249 TosaErrorValidator.evWrongRank,
4250 TosaErrorValidator.evWrongInputList,
4251 TosaErrorValidator.evWrongOutputList,
4252 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004253 "data_gen": {
4254 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4255 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004257 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004258 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004259 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004260 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004261 "build_fcn": (
4262 build_reduce,
4263 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004264 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 TosaArgGen.agAxis,
4266 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004267 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004268 "error_if_validators": (
4269 TosaErrorValidator.evAxisLargerRank,
4270 TosaErrorValidator.evAxisSmallerZero,
4271 TosaErrorValidator.evShapeOfAxisNotOne,
4272 TosaErrorValidator.evWrongInputType,
4273 TosaErrorValidator.evWrongOutputType,
4274 TosaErrorValidator.evWrongRank,
4275 TosaErrorValidator.evWrongInputList,
4276 TosaErrorValidator.evWrongOutputList,
4277 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004278 "data_gen": {
4279 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004281 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004282 "reduce_product": {
4283 "op": Op.REDUCE_PRODUCT,
4284 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004285 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004286 "build_fcn": (
4287 build_reduce,
4288 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004289 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004290 TosaArgGen.agAxis,
4291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004292 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004293 "error_if_validators": (
4294 TosaErrorValidator.evAxisLargerRank,
4295 TosaErrorValidator.evAxisSmallerZero,
4296 TosaErrorValidator.evShapeOfAxisNotOne,
4297 TosaErrorValidator.evWrongInputType,
4298 TosaErrorValidator.evWrongOutputType,
4299 TosaErrorValidator.evWrongRank,
4300 TosaErrorValidator.evWrongInputList,
4301 TosaErrorValidator.evWrongOutputList,
4302 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004303 "data_gen": {
4304 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004306 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004307 "reduce_sum": {
4308 "op": Op.REDUCE_SUM,
4309 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004310 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 "build_fcn": (
4312 build_reduce,
4313 TosaTensorGen.tgBasic,
4314 TosaTensorValuesGen.tvgReduceSum,
4315 TosaArgGen.agAxis,
4316 ),
James Ward24dbc422022-10-19 12:20:31 +01004317 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004318 "error_if_validators": (
4319 TosaErrorValidator.evAxisLargerRank,
4320 TosaErrorValidator.evAxisSmallerZero,
4321 TosaErrorValidator.evShapeOfAxisNotOne,
4322 TosaErrorValidator.evWrongInputType,
4323 TosaErrorValidator.evWrongOutputType,
4324 TosaErrorValidator.evWrongRank,
4325 TosaErrorValidator.evWrongInputList,
4326 TosaErrorValidator.evWrongOutputList,
4327 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004328 "data_gen": {
4329 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4330 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004331 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004332 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004333 "concat": {
4334 "op": Op.CONCAT,
4335 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004336 "build_fcn": (
4337 build_concat,
4338 TosaTensorGen.tgConcat,
4339 TosaTensorValuesGen.tvgConcat,
4340 TosaArgGen.agAxis,
4341 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004343 "error_if_validators": (
4344 TosaErrorValidator.evAxisLargerRank,
4345 TosaErrorValidator.evAxisSmallerZero,
4346 TosaErrorValidator.evConcatInputRankMismatch,
4347 TosaErrorValidator.evConcatShapeSumMismatch,
4348 TosaErrorValidator.evConcatInputDimMismatch,
4349 TosaErrorValidator.evWrongInputType,
4350 TosaErrorValidator.evWrongOutputType,
4351 TosaErrorValidator.evWrongOutputList,
4352 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004353 "data_gen": {
4354 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4355 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004356 },
4357 "pad": {
4358 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004359 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004360 "build_fcn": (
4361 build_pad,
4362 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004363 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004364 TosaArgGen.agPad,
4365 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004366 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004367 "error_if_validators": (
4368 TosaErrorValidator.evWrongInputType,
4369 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004370 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004371 TosaErrorValidator.evWrongOutputType,
4372 TosaErrorValidator.evWrongInputList,
4373 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004374 TosaErrorValidator.evRankMismatch,
4375 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004376 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004377 "data_gen": {
4378 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4379 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004380 },
Won Jeona21b2e82023-08-10 10:33:01 +00004381 "dim": {
4382 "op": Op.DIM,
4383 "operands": (1, 0),
4384 "build_fcn": (
4385 build_dim,
4386 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004387 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004388 TosaArgGen.agAxis,
4389 ),
4390 "types": TYPE_FIB,
4391 "error_if_validators": (
4392 TosaErrorValidator.evAxisLargerRank,
4393 TosaErrorValidator.evAxisSmallerZero,
4394 TosaErrorValidator.evWrongInputType,
4395 TosaErrorValidator.evWrongInputList,
4396 TosaErrorValidator.evWrongOutputList,
4397 TosaErrorValidator.evWrongRank,
4398 ),
4399 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004400 "reshape": {
4401 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004402 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004403 "build_fcn": (
4404 build_reshape,
4405 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004406 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004407 TosaArgGen.agReshape,
4408 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004409 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004410 "error_if_validators": (
4411 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4412 TosaErrorValidator.evWrongInputType,
4413 TosaErrorValidator.evWrongOutputType,
4414 TosaErrorValidator.evWrongInputList,
4415 TosaErrorValidator.evWrongOutputList,
4416 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004417 "data_gen": {
4418 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4419 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004420 },
4421 "reverse": {
4422 "op": Op.REVERSE,
4423 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004424 "build_fcn": (
4425 build_reverse,
4426 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004427 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004428 TosaArgGen.agAxis,
4429 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004430 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004431 "error_if_validators": (
4432 TosaErrorValidator.evAxisSmallerZero,
4433 TosaErrorValidator.evAxisLargerRank,
4434 TosaErrorValidator.evWrongInputType,
4435 TosaErrorValidator.evWrongOutputType,
4436 TosaErrorValidator.evWrongInputList,
4437 TosaErrorValidator.evWrongOutputList,
4438 ),
evacha0198477222024-01-26 12:25:32 +00004439 "data_gen": {
4440 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4441 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004442 },
4443 "slice": {
4444 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004445 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004446 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004447 "build_fcn": (
4448 build_slice,
4449 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004450 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004451 TosaArgGen.agSlice,
4452 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004453 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004455 # TODO Turn off these error categories for now as the reference
4456 # model cannot allocate memory space for empty tensor. We probably
4457 # can report an accurate error message at the right place during
4458 # execution.
4459 # TosaErrorValidator.evStartSmallerZero,
4460 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004461 TosaErrorValidator.evStartSizeOutsideBounds,
4462 TosaErrorValidator.evSizeOutputShapeMismatch,
4463 TosaErrorValidator.evInputSizeStartLengthMismatch,
4464 TosaErrorValidator.evWrongRank,
4465 TosaErrorValidator.evWrongInputType,
4466 TosaErrorValidator.evWrongOutputType,
4467 TosaErrorValidator.evWrongInputList,
4468 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004469 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 ),
evacha017f7d4252024-01-24 12:08:09 +00004471 "data_gen": {
4472 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4473 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004474 },
4475 "tile": {
4476 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004477 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004478 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004479 "build_fcn": (
4480 build_tile,
4481 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004482 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004483 TosaArgGen.agTile,
4484 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004485 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004486 "error_if_validators": (
4487 TosaErrorValidator.evWrongInputType,
4488 TosaErrorValidator.evWrongOutputType,
4489 TosaErrorValidator.evWrongInputList,
4490 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004491 TosaErrorValidator.evRankMismatch,
4492 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004493 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004494 "data_gen": {
4495 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4496 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004497 },
4498 "transpose": {
4499 "op": Op.TRANSPOSE,
4500 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004501 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004502 "build_fcn": (
4503 build_transpose,
4504 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004505 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004506 TosaArgGen.agTranspose,
4507 ),
4508 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004509 "error_if_validators": (
4510 TosaErrorValidator.evIndexOutsideBounds,
4511 TosaErrorValidator.evIndexUsedTwice,
4512 TosaErrorValidator.evWrongInputType,
4513 TosaErrorValidator.evWrongOutputType,
4514 TosaErrorValidator.evWrongInputList,
4515 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004516 TosaErrorValidator.evWrongRank,
4517 TosaErrorValidator.evRankMismatch,
4518 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004519 ),
evacha0198477222024-01-26 12:25:32 +00004520 "data_gen": {
4521 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4522 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004524 # Data nodes
4525 "const": {
4526 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004527 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004528 "build_fcn": (
4529 build_const,
4530 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004531 TosaTensorValuesGen.tvgLazyGenDefault,
4532 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004533 ),
Luke Hutton65872422023-02-20 10:33:04 +00004534 "types": TYPE_FIB + [DType.INT48],
evacha0198477222024-01-26 12:25:32 +00004535 "data_gen": {
4536 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004539 "identity": {
4540 "op": Op.IDENTITY,
4541 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004542 "build_fcn": (
4543 build_unary,
4544 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004545 TosaTensorValuesGen.tvgLazyGenDefault,
4546 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004548 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004549 "data_gen": {
4550 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004552 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004553 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004554 "gather": {
4555 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004556 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004558 "build_fcn": (
4559 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004560 TosaTensorGen.tgGather,
4561 TosaTensorValuesGen.tvgGather,
4562 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004563 ),
James Ward24dbc422022-10-19 12:20:31 +01004564 "types": (
4565 DType.INT8,
4566 DType.INT16,
4567 DType.INT32,
4568 DType.FP16,
4569 DType.BF16,
4570 DType.FP32,
4571 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004572 "error_if_validators": (
4573 TosaErrorValidator.evWrongInputType,
4574 TosaErrorValidator.evWrongOutputType,
4575 TosaErrorValidator.evWrongInputList,
4576 TosaErrorValidator.evWrongOutputList,
4577 TosaErrorValidator.evWrongRank,
4578 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004579 "data_gen": {
4580 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4581 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004582 },
4583 "scatter": {
4584 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004585 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004586 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004587 "build_fcn": (
4588 build_scatter,
4589 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004590 TosaTensorValuesGen.tvgScatter,
4591 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004592 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004593 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004594 "error_if_validators": (
4595 TosaErrorValidator.evWrongInputType,
4596 TosaErrorValidator.evWrongOutputType,
4597 TosaErrorValidator.evWrongInputList,
4598 TosaErrorValidator.evWrongOutputList,
4599 TosaErrorValidator.evWrongRank,
4600 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004601 "data_gen": {
4602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4603 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004605 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004606 "resize": {
4607 "op": Op.RESIZE,
4608 "operands": (1, 0),
4609 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004610 "build_fcn": (
4611 build_resize,
4612 TosaTensorGen.tgNHWC,
4613 TosaTensorValuesGen.tvgDefault,
4614 TosaArgGen.agResize,
4615 ),
James Ward24dbc422022-10-19 12:20:31 +01004616 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004617 "invalid_test_validators": (
4618 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004619 ),
4620 "error_if_validators": (
4621 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004622 TosaErrorValidator.evScaleSmallerEqualZero,
4623 TosaErrorValidator.evScaleNLargerMax,
4624 TosaErrorValidator.evScaleDLargerMax,
4625 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004626 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004627 TosaErrorValidator.evBorderSmallerMin,
4628 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004629 TosaErrorValidator.evWrongInputType,
4630 TosaErrorValidator.evWrongOutputType,
4631 TosaErrorValidator.evWrongRank,
4632 TosaErrorValidator.evWrongInputList,
4633 TosaErrorValidator.evWrongOutputList,
4634 TosaErrorValidator.evBatchMismatch,
4635 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004636 TosaErrorValidator.evResizeOutputShapeMismatch,
4637 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004639 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004640 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004641 "cast": {
4642 "op": Op.CAST,
4643 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004644 "build_fcn": (
4645 build_cast,
4646 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004647 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004648 TosaArgGen.agCast,
4649 ),
James Ward8b390432022-08-12 20:48:56 +01004650 "types": (
4651 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004652 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004653 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004654 DType.INT8,
4655 DType.INT16,
4656 DType.INT32,
4657 DType.BOOL,
4658 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004659 "error_if_validators": (
4660 TosaErrorValidator.evWrongInputType,
4661 TosaErrorValidator.evWrongOutputType,
4662 TosaErrorValidator.evWrongInputList,
4663 TosaErrorValidator.evWrongOutputList,
4664 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004665 "data_gen": {
4666 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4667 },
4668 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004669 },
4670 "rescale": {
4671 "op": Op.RESCALE,
4672 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004673 "build_fcn": (
4674 build_rescale,
4675 TosaTensorGen.tgBasic,
4676 TosaTensorValuesGen.tvgDefault,
4677 TosaArgGen.agRescale,
4678 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004679 "types": [
4680 DType.UINT8,
4681 DType.INT8,
4682 DType.INT16,
4683 DType.INT32,
4684 DType.INT48,
4685 DType.UINT16,
4686 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004687 "error_if_validators": (
4688 TosaErrorValidator.evInputZeroPointNotZero,
4689 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004690 TosaErrorValidator.evU16InputZeroPointNotValid,
4691 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004692 TosaErrorValidator.evScaleTrue,
4693 TosaErrorValidator.evScaleNotTrue,
4694 TosaErrorValidator.evWrongInputType,
4695 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004696 TosaErrorValidator.evWrongInputList,
4697 TosaErrorValidator.evWrongOutputList,
4698 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004699 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004700 # Custom
4701 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004702 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004703 # Two variants of cond_if, one that generates one of two constant tensors (no
4704 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4705 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004706 "cond_if_const": {
4707 "op": Op.COND_IF,
4708 "operands": (0, 2),
4709 "build_fcn": (
4710 build_cond_if_const,
4711 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004712 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004713 TosaArgGen.agCondIf,
4714 ),
4715 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004716 "error_if_validators": (
4717 TosaErrorValidator.evOutputListThenGraphMismatch,
4718 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004719 TosaErrorValidator.evCondIfCondNotMatchingBool,
4720 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004721 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004722 },
4723 "cond_if_binary": {
4724 "op": Op.COND_IF,
4725 "operands": (2, 0),
4726 "build_fcn": (
4727 build_cond_if_binary,
4728 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004729 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004730 TosaArgGen.agCondIf,
4731 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004732 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004733 "error_if_validators": (
4734 TosaErrorValidator.evInputListThenGraphMismatch,
4735 TosaErrorValidator.evInputListElseGraphMismatch,
4736 TosaErrorValidator.evOutputListThenGraphMismatch,
4737 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004738 TosaErrorValidator.evCondIfCondNotMatchingBool,
4739 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004741 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004742 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004743 "while_loop": {
4744 "op": Op.WHILE_LOOP,
4745 "operands": (0, 1),
4746 "build_fcn": (
4747 build_while_loop,
4748 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004749 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004750 TosaArgGen.agWhileLoop,
4751 ),
4752 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004753 "error_if_validators": (
4754 TosaErrorValidator.evInputListOutputListMismatch,
4755 TosaErrorValidator.evInputListCondGraphMismatch,
4756 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4757 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4758 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004759 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004760 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004761 },
Luke Hutton57287132023-02-06 14:54:18 +00004762 "fft2d": {
4763 "op": Op.FFT2D,
4764 "operands": (2, 0),
4765 "rank": (3, 3),
4766 "build_fcn": (
4767 build_fft2d,
4768 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004769 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004770 TosaArgGen.agFFT2d,
4771 ),
4772 "types": [DType.FP32],
4773 "error_if_validators": (
4774 TosaErrorValidator.evWrongInputType,
4775 TosaErrorValidator.evWrongOutputType,
4776 TosaErrorValidator.evWrongInputList,
4777 TosaErrorValidator.evWrongOutputList,
4778 TosaErrorValidator.evWrongRank,
4779 TosaErrorValidator.evBatchMismatch,
4780 TosaErrorValidator.evKernelNotPowerOfTwo,
4781 TosaErrorValidator.evFFTInputShapeMismatch,
4782 TosaErrorValidator.evFFTOutputShapeMismatch,
4783 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004784 "data_gen": {
4785 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4786 },
Luke Hutton57287132023-02-06 14:54:18 +00004787 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004788 "rfft2d": {
4789 "op": Op.RFFT2D,
4790 "operands": (1, 0),
4791 "rank": (3, 3),
4792 "build_fcn": (
4793 build_rfft2d,
4794 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004795 TosaTensorValuesGen.tvgLazyGenDefault,
4796 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004797 ),
4798 "types": [DType.FP32],
4799 "error_if_validators": (
4800 TosaErrorValidator.evWrongInputType,
4801 TosaErrorValidator.evWrongOutputType,
4802 TosaErrorValidator.evWrongInputList,
4803 TosaErrorValidator.evWrongOutputList,
4804 TosaErrorValidator.evWrongRank,
4805 TosaErrorValidator.evBatchMismatch,
4806 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004807 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004808 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004809 "data_gen": {
4810 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4811 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004812 },
Won Jeon74342e52024-01-09 00:34:40 +00004813 # Shape
4814 "add_shape": {
4815 "op": Op.ADD_SHAPE,
4816 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004817 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004818 "build_fcn": (
4819 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004820 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004821 TosaTensorValuesGen.tvgAddSub,
4822 TosaArgGen.agNone,
4823 ),
4824 "types": [DType.SHAPE],
4825 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4826 },
4827 "sub_shape": {
4828 "op": Op.SUB_SHAPE,
4829 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004830 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004831 "build_fcn": (
4832 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004833 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004834 TosaTensorValuesGen.tvgAddSub,
4835 TosaArgGen.agNone,
4836 ),
4837 "types": [DType.SHAPE],
4838 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4839 },
4840 "mul_shape": {
4841 "op": Op.MUL_SHAPE,
4842 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004843 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004844 "build_fcn": (
4845 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004846 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004847 TosaTensorValuesGen.tvgMul,
4848 TosaArgGen.agNone,
4849 ),
4850 "types": [DType.SHAPE],
4851 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4852 },
4853 "div_shape": {
4854 "op": Op.DIV_SHAPE,
4855 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004856 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004857 "build_fcn": (
4858 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004859 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004860 TosaTensorValuesGen.tvgIntDiv,
4861 TosaArgGen.agNone,
4862 ),
4863 "types": [DType.SHAPE],
4864 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4865 },
4866 "concat_shape": {
4867 "op": Op.CONCAT_SHAPE,
4868 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004869 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004870 "build_fcn": (
4871 build_concat,
4872 TosaTensorGen.tgConcat,
4873 TosaTensorValuesGen.tvgConcat,
4874 TosaArgGen.agNone,
4875 ),
4876 "types": [DType.SHAPE],
4877 "error_if_validators": (),
4878 },
4879 "const_shape": {
4880 "op": Op.CONST_SHAPE,
4881 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004882 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004883 "build_fcn": (
4884 build_const,
4885 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004886 TosaTensorValuesGen.tvgLazyGenDefault,
4887 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004888 ),
4889 "types": [DType.SHAPE],
4890 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004891 }
4892
Kevin Cheng550ccc52021-03-03 11:21:43 -08004893
Eric Kunzee5e26762020-10-13 16:11:07 -07004894class OutputShaper:
4895 # Methods in this class compute the expected output shape and datatype
4896 # for common classes of operations
def __init__(self):
    """No per-instance state; every shaper below is a static method."""
4899
4900 # These methods return arguments that can be used for
4901 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting elementwise binary operator.

    The output shape follows ``a``. When no ERROR_IF case is being
    generated, any dimension of size 1 in ``a`` takes ``b``'s size instead.

    Args:
        ser: Serializer; the output is created via ``ser.addOutput``.
        rng: numpy random Generator used for error-case choices.
        a: First input tensor (``shape`` and ``dtype`` are read).
        b: Second input tensor. It must have the same dtype as ``a``, and
            the same rank unless ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: Optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 (scalar) input has no
    # dimension to fuzz and must skip the draw. For rank >= 1 the draw is
    # made exactly as before, so the generated random stream is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004935
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for an elementwise binary op without broadcast.

    Both operands must agree exactly in rank, dtype and every dimension.
    The output takes that common shape and ``a``'s dtype.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
        out_shape.append(dim_a)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004947
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary operator.

    The output keeps ``a``'s shape. Its dtype is ``a``'s dtype, except for
    the WrongOutputType ERROR_IF case, where a different dtype is drawn at
    random.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004966
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT.

    The output shape follows ``cond``. When no ERROR_IF case is being
    generated, a size-1 dimension of ``cond`` becomes the largest of the
    three operands' sizes for that dimension.

    Args:
        ser: Serializer; the output is created via ``ser.addOutput``.
        rng: numpy random Generator used for error-case choices.
        cond: Condition tensor.
        a, b: Value tensors. They must share a dtype, and (unless
            ``error_name`` is ``ErrorIf.RankMismatch``) a rank with ``cond``.
        error_name: Optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 (scalar) input has no
    # dimension to fuzz and must skip the draw. For rank >= 1 the draw is
    # made exactly as before, so the generated random stream is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005000
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an elementwise comparison operator.

    The output shape broadcasts ``a`` against ``b``: a size-1 dimension of
    ``a`` takes ``b``'s size wherever ``b`` has that dimension.

    Args:
        ser: Serializer; the output is created via ``ser.addOutput``.
        rng: numpy random Generator used for error-case choices.
        a, b: Input tensors. They must share a dtype, and a rank unless
            ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: Optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 (scalar) input has no
    # dimension to fuzz and must skip the draw. For rank >= 1 the draw is
    # made exactly as before, so the generated random stream is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is the only valid output type, so every entry here is wrong.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005034
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a REDUCE_* operator.

    Normally the reduced ``axis`` becomes size 1. For the axis ERROR_IF
    cases the input shape is kept unchanged. For ShapeOfAxisNotOne, an axis
    that happens to be size 1 is replaced with a random size in [2, 10).
    The dtype matches ``a`` unless WrongOutputType asks for a different one.
    """
    out_shape = a.shape.copy()
    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005063
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX.

    The output is ``a``'s shape with ``axis`` removed and dtype INT32.
    ERROR_IF cases may keep the axis (invalid axis values), change the rank
    or dimensions, or choose a non-INT32 dtype.
    """
    out_shape = a.shape.copy()

    # Only drop the reduced axis when it is a usable index.
    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Randomly lose a leading dimension (if one can be spared), else add one.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005097
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI and OFM is NHWC. For spatial axis
    ``i`` (0 = H, 1 = W) the output size is
    ``(in - 1 + pad[2i] + pad[2i+1] - (k - 1) * dilation[i]) // stride[i] + 1``.
    The output dtype is ``accum_dtype`` apart from ERROR_IF cases.
    """
    spatial = []
    for axis in range(2):
        numerator = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        )
        spatial.append(numerator // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change 1 -> H only, 2 -> W only, 3 -> both. Growing by whole
        # strides avoids tripping the non-integer-output error case.
        for axis in range(2):
            if change in [axis + 1, 3]:
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], filter.shape[0]]

    # For a wrong input type, fall back to some potentially valid dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005149
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI and OFM is NDHWC. For spatial
    axis ``i`` (0 = D, 1 = H, 2 = W) the output size is
    ``(in - 1 + pad[2i] + pad[2i+1] - (k - 1) * dilation[i]) // stride[i] + 1``.
    The output dtype is ``accum_dtype`` apart from ERROR_IF cases.
    """
    spatial = [
        (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 grows D/H/W alone, 4 grows all three. Growing by whole
        # strides avoids tripping the non-integer-output error case.
        for axis in range(3):
            if change in [axis + 1, 4]:
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], spatial[2], filter.shape[0]]

    # For a wrong input type, fall back to some potentially valid dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5211
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM and OFM is NHW(C*M). The spatial
    sizes use the same formula as CONV2D, but the kernel H/W come from
    filter dims 0 and 1. The output dtype is ``accum_dtype`` apart from
    ERROR_IF cases.
    """
    def _out_size(axis):
        # axis 0 = H, 1 = W
        span = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        )
        return span // strides[axis] + 1

    out_h = _out_size(0)
    out_w = _out_size(1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so only the shape-mismatch check fires.
        if change in [1, 3]:
            out_h = out_h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            out_w = out_w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # For a wrong input type, fall back to some potentially valid dtype.
    if error_name == ErrorIf.WrongInputType:
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005262
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (NHWC layout).

    Spatial sizes are ``(in + pad_lo + pad_hi - kernel) // stride + 1``.
    A non-positive stride or negative padding gives a placeholder 1x1
    spatial size, because such a test is invalid anyway.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        out_h, out_w = 1, 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so only the shape-mismatch check fires.
        if change in [1, 3]:
            out_h = out_h + (rng.choice(choices) * stride[0])
        if change in [2, 3]:
            out_w = out_w + (rng.choice(choices) * stride[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtype = rng.choice(list(set(candidate_dtypes) - set([ifm.dtype])))
    return ser.addOutput(ofm_shape, wrong_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005300
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    ``input`` is (N, IC) and ``filter`` is (OC, IC), so the output is
    (N, OC). The dtype is ``accum_dtype`` as given: it is validated (or
    deliberately invalidated for ERROR_IF tests) when the arguments are
    generated, so ``rng`` and ``error_name`` are not consulted here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005313
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    ``a`` is (N, H, C) and ``b`` is (N, C, W), so the output is (N, H, W).

    Args:
        ser: Serializer; the output is created via ``ser.addOutput``.
        rng: numpy random Generator used for error-case choices.
        a, b: Input tensors.
        accum_dtype: Accumulator/output dtype (validated in arg_gen).
        error_name: Optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype: pick any known dtype other than the
            # expected accumulator type, so incorrect_types is never unbound.
            incorrect_types = tuple(
                dt
                for dt in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dt != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005361
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT.

    The output starts as a copy of the first input's shape, and the
    ``axis`` size of every other input is added to it. The sum is skipped
    when the shapes cannot be combined (rank mismatch or an invalid axis).
    In that case the first input's shape is used as-is.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    cannot_combine = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_combine:
        for extra in inputs[1:]:
            output_shape[axis] += extra.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Over-size the concatenated axis so the shape-sum check fails.
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(all_dtypes - set([first.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005397
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the PAD result tensor to ``ser``.

    Output dimension ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ERROR_IF variants may enlarge one dimension, change the rank, or pick a
    dtype other than ``a.dtype``.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (``choice`` is called on it).
        a: input tensor.
        padding: per-dimension ``[before, after]`` pad amounts.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = [
        padding[dim_idx][0] + padding[dim_idx][1] + extent
        for dim_idx, extent in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Enlarge one randomly chosen dimension by 1 or 2.
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005428
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the DIM result tensor to ``ser``.

    The output shape is always ``[1]``. Its dtype is ``DType.SHAPE`` unless
    a WrongOutputType ERROR_IF test is being generated, in which case a
    random non-SHAPE dtype is chosen. ``a`` and ``axis`` are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # The set() round-trip determines the candidate order seen by
    # rng.choice; keep it so generated tests stay reproducible.
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5449
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the RESHAPE result tensor to ``ser``.

    The output shape is a copy of ``shape``. For TensorSizeInputOutputMismatch
    every dimension is increased by a random 1..9, so the element count no
    longer matches. For WrongOutputType a dtype other than ``a.dtype`` is
    chosen.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for pos, extent in enumerate(out_shape):
            out_shape[pos] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005474
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the SLICE result tensor to ``ser``.

    The output shape is a copy of ``size`` and the dtype is ``input.dtype``.
    ERROR_IF variants perturb every dimension, reuse the input shape, change
    the rank, or pick a wrong dtype. ``start`` is not used here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for pos, extent in enumerate(out_shape):
            # Dims of 2 or less only grow, so they cannot drop to zero or
            # below; larger dims may shrink or grow by up to 2.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            out_shape[pos] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005508
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the TILE result tensor to ``ser``.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``. ERROR_IF
    variants may change the rank or pick a dtype other than ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [a.shape[i] * multiples[i] for i in range(len(a.shape))]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005537
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the TRANSPOSE result tensor to ``ser``.

    Output dimension ``i`` is ``a.shape[perms[i]]``. When ``perms`` is
    deliberately invalid (IndexOutsideBounds / IndexUsedTwice) the input
    shape is used unpermuted. Other ERROR_IF variants inflate every
    dimension, change the rank, or pick a wrong dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(perms) == len(a.shape)

    if error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        out_shape = a.shape.copy()
    else:
        out_shape = [a.shape[perms[i]] for i in range(len(a.shape))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for pos, extent in enumerate(out_shape):
            out_shape[pos] = extent + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005570
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the GATHER result tensor to ``ser``.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``
    and the dtype is ``values.dtype``, unless a wrong dtype is requested for
    a WrongOutputType ERROR_IF test.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    n_dim = values.shape[0]
    w_dim = indices.shape[1]
    c_dim = values.shape[2]
    out_shape = [n_dim, w_dim, c_dim]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {values.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005596
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the SCATTER result tensor to ``ser``.

    The output has ``values_in``'s shape; the same shape object is passed on,
    not a copy. The dtype is ``values_in.dtype`` unless a WrongOutputType
    ERROR_IF test asks for a different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values_in.dtype
    else:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005625
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the TABLE result tensor to ``ser``.

    The output has the input's shape. An INT16 input gives an INT32 output;
    any other input dtype gives INT8. For WrongOutputType a different dtype
    is picked at random.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != out_dtype
        ]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005645
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the RESIZE result tensor to ``serializer``.

    The output size per spatial axis is
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``. Height uses
    ``input.shape[1]`` with ``scale[0:2]``, ``offset[0]`` and ``border[0]``;
    width uses ``input.shape[2]`` with ``scale[2:4]``, ``offset[1]`` and
    ``border[1]``. For ERROR_IF tests the size is first clamped to a usable
    range and may then be perturbed. ``mode`` and ``input_dtype`` are not
    used here.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # The test corrupts ``scale``; use at least 1 so a size is computable.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # ERROR_IF tests may alter scale/offset/border; keep the tensor valid.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, limit), min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _nudge(size, step):
            # Move by a whole denominator step so the non-integer-size
            # ERROR_IF is not triggered as well; shrink if growing would
            # hit the maximum dimension.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        change = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        if change in (1, 3):
            out_h = _nudge(out_h, y_den)
        if change in (2, 3):
            out_w = _nudge(out_w, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses input.shape[0] rather than
        # input.shape[3]; confirm this is intentional.
        out_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        out_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        out_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        out_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(out_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005724
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the result tensor of a type-conversion op to ``ser``.

    The output keeps ``val``'s shape and takes ``out_dtype`` as its dtype.
    ``rng`` and ``error_name`` are accepted to match the signatures of the
    other output shapers but are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    in_shape = val.shape
    return ser.addOutput(in_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005728
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the TRANSPOSE_CONV2D result tensor to ``ser``.

    ``output_shape`` is used as given. NOTE: for ConvOutputShapeMismatch it
    is modified in place (dims 1 and/or 2 are increased), so the caller's
    list changes too. The dtype is ``accum_dtype``, or INT32 when the input
    type is deliberately wrong, or a random wrong dtype for WrongOutputType.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: dim 1, 2: dim 2, 3: both
        if which in (1, 3):
            output_shape[1] = output_shape[1] + rng.choice(steps)
        if which in (2, 3):
            output_shape[2] = output_shape[2] + rng.choice(steps)

    # With a wrong input type, INT32 stands in as a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # Presumably FP16 input permits both FP16 and FP32 outputs, so both
        # are excluded from the "wrong" candidates.
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005754
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two FFT2D result tensors to ``serializer``.

    Both outputs get a copy of ``ifm1``'s shape and ``ifm1``'s dtype.
    ERROR_IF variants pick a non-FP32 dtype, grow the batch dim (0), or grow
    dim 1 or 2.

    Returns:
        A list of the two values returned by ``serializer.addOutput``
        (presumably the real and imaginary parts).
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grown_dim = rng.choice([1, 2])
        out_shape[grown_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5785
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two RFFT2D result tensors to ``serializer``.

    The output shape equals ``value.shape`` except that the last dimension
    becomes ``last // 2 + 1``. ERROR_IF variants pick a non-FP32 dtype, grow
    the batch dim (0), or grow dim 1 or 2.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(value.shape) == 3

    head, last = value.shape[:-1], value.shape[-1]
    out_shape = list(head) + [last // 2 + 1]

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grown_dim = rng.choice([1, 2])
        out_shape[grown_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005810
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the ADD_SHAPE result tensor to ``ser``.

    The output shape is a copy of ``a.shape`` and its dtype is ``DType.SHAPE``.
    For DimensionMismatch one randomly chosen dim is incremented; for
    WrongOutputType a dtype from ``gtu.get_wrong_output_type`` is used.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = list(a.shape)

    # Drawn for every error case, so rng consumption does not depend on
    # error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.SHAPE
    else:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    return ser.addOutput(out_shape, out_dtype)