blob: 2d471c054aeadca4c58a1fc73e04e44b55d5e4e8 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
Jeremy Johnson1271c442023-09-05 11:39:26 +010025TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
26// SPDX-License-Identifier: Apache-2.0
27// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
28"""
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the test generator from parsed command line arguments.

    Reads ``output_dir``, ``random_seed``, ``generate_lib_path`` and
    ``tensor_fp_value_range`` from ``args`` and keeps ``args`` itself for
    later use by the other methods.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer until createSerializer() is called
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this shape instead of a random one
    self.targetted_shape = None
    # Validator used on desc.json contents before they are written out
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        """Resolve "max"/"-max" and clamp numeric entries to +/- maxFP."""

        def resolve(value):
            if value == "max":
                return maxFP
            if value == "-max":
                return -maxFP
            if value < 0:
                # Trim to minimum data type value
                return max(value, -maxFP)
            if value > 0:
                # Trim to maximum data type value
                return min(value, maxFP)
            return value

        return tuple(sorted(resolve(entry) for entry in rangeFP))

    # Per floating point type (low, high) range used for random values
    self.random_float_range = {
        fpType: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fpType],
        )
        for fpType in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create the test output directory and a fresh serializer for it.

    Stores ``opName/testPath`` in ``self.testPath``, creates that directory
    under ``self.basePath`` if needed, and assigns a new
    ``ts.TosaSerializer`` to ``self.ser``.
    """
    self.testPath = os.path.join(opName, testPath)

    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        constMode = ts.ConstMode.EMBED

    self.ser = ts.TosaSerializer(outputDir, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return whatever serializer is currently held in ``self.ser``."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test's output files into its test directory.

    Files are written in this order: ``<testName>.tosa`` (flatbuffer),
    optional ``*_meta_data_gen.cpp`` / ``*_meta_compliance.cpp`` files
    built from ``metaData``, and finally ``desc.json``. The descriptor is
    schema-validated before any of the meta data or desc.json files are
    written.
    """
    testDir = Path(self.basePath) / self.testPath
    flatbufferName = f"{testName}.tosa"

    # Write out TOSA flatbuffer binary
    with (testDir / flatbufferName).open("wb") as fbFile:
        fbFile.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(flatbufferName))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    def writeCppMetaData(fileName, banner, declaration, payload):
        """Emit ``payload`` as JSON inside a C++ raw string literal."""
        with (testDir / fileName).open("w") as cppFile:
            cppFile.write(TOSA_AUTOGENERATED_HEADER)
            cppFile.write(banner)
            cppFile.write(declaration)
            json.dump(payload, cppFile)
            cppFile.write(')";\n\n')

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output datagen meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{testDir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{testDir.stem} = R"(',
                metaData["compliance"],
            )

    # Write desc.json
    with (testDir / "desc.json").open("w") as descFile:
        json.dump(desc, descFile, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a newly seeded NumPy generator.

    With no explicit ``seed`` the generator is seeded with
    ``self.random_seed + 1``.
    """
    newSeed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(newSeed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value range boundaries for ``dtype``.

    The high boundary is excluded from the range unless ``high_inclusive``
    is True. Floating point types return the stored
    ``self.random_float_range`` entry unchanged, whatever
    ``high_inclusive`` is.

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # Exclusive-high ranges; None marks a range taken from the arguments
    knownRanges = (
        (DType.BOOL, (0, 2)),
        (DType.UINT8, (0, 256)),
        (DType.UINT16, (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), (1 << 31))),
        (DType.SHAPE, None),
        (DType.INT48, (-(1 << 47), (1 << 47))),
    )
    for knownType, bounds in knownRanges:
        if dtype == knownType:
            if bounds is None:
                bounds = tuple(self.args.tensor_shape_range[0:2])
            break
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= range < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random NumPy array of the given ``shape`` for ``dtype``.

    Values are drawn from ``data_range`` (low, high) when given, otherwise
    from ``self.getDTypeRange(dtype)``. BOOL tensors ignore the range and
    choose between False and True.
    """
    if data_range is None:
        low, high = self.getDTypeRange(dtype)
    else:
        low, high = data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        values = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(values)
        values32 = np.float32(values)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(values32))
        return values32

    drawn = self.rng.integers(low=low, high=high, size=shape)
    if dtype == DType.INT8:
        return np.int8(drawn)
    if dtype == DType.UINT8:
        return np.uint8(drawn)
    if dtype in (DType.INT48, DType.SHAPE):
        return np.int64(drawn)
    # All other integer types
    return np.int32(drawn)
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder to the serializer per (shape, dtype) pair.

    Random data is attached unless ``self.args.lazy_data_gen`` is set, in
    which case None is passed as the data. Returns the list of values
    returned by ``self.ser.addPlaceholder``.
    """
    assert len(shape_list) == len(dtype_list)

    def makeData(shape, dtype):
        if self.args.lazy_data_gen:
            return None
        return self.getRandTensor(shape, dtype)

    return [
        self.ser.addPlaceholder(
            shape, dtype_list[idx], makeData(shape, dtype_list[idx])
        )
        for idx, shape in enumerate(shape_list)
    ]
228
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant to the serializer per (shape, dtype) pair.

    Random data is attached unless ``self.args.lazy_data_gen`` is set, in
    which case None is passed as the data. Returns the list of values
    returned by ``self.ser.addConst``.
    """
    assert len(shape_list) == len(dtype_list)

    def makeData(shape, dtype):
        if self.args.lazy_data_gen:
            return None
        return self.getRandTensor(shape, dtype)

    return [
        self.ser.addConst(shape, dtype_list[idx], makeData(shape, dtype_list[idx]))
        for idx, shape in enumerate(shape_list)
    ]
241
def makeShape(self, rank):
    """Return an int32 shape array of ``rank`` dimensions.

    If a target shape has been set it is returned (converted to int32)
    regardless of ``rank``; otherwise each dimension is drawn from
    ``self.args.tensor_shape_range`` (low inclusive, high exclusive).
    """
    if not self.targetted_shape:
        dims = self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
def setTargetShape(self, shape):
    """Store ``shape`` as the forced shape that makeShape() will return."""
    forcedShape = shape
    self.targetted_shape = forcedShape
255
def randInt(self, low=0, high=256):
    """Return one random int32 in the half-open range [low, high)."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
258
def getRandNumberDType(self, dtype):
    """Return a single random value drawn from the range of ``dtype``.

    The range comes from ``self.getDTypeRange(dtype)``. BOOL picks between
    False and True; INT48 and SHAPE give an int64, all other non-float
    types an int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample32 = np.float32(sample)
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample32)
        return sample32

    drawn = self.rng.integers(low, high, size=1)
    if dtype == DType.INT48 or dtype == DType.SHAPE:
        # Special size
        return np.int64(drawn)[0]
    return np.int32(drawn)[0]
276
def shapeStr(self, shape):
    """Return the dimensions of ``shape`` joined by "x" (e.g. "1x2x3")."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeStr(self, dtype):
    """Return the string name for ``dtype`` from ``gtu.DTYPE_ATTRIBUTES``.

    For a list or tuple of types (at least two) every member is converted,
    but only the first two names are joined with "x".

    Raises:
        Exception: if a type has no entry in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(member) for member in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Raises:
        Exception: if ``dtype`` has no entry in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
def constrictBatchSize(self, shape):
    """Cap ``shape[0]`` at ``self.args.max_batch_size`` and return ``shape``.

    The shape is modified in place. Nothing is changed when no maximum
    batch size is configured or when explicit target shapes were given.
    """
    batchLimit = self.args.max_batch_size
    if not batchLimit or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], batchLimit)
    return shape
312
def makeDimension(self):
    """Return one random dimension drawn from ``self.args.tensor_shape_range``."""
    shapeRange = self.args.tensor_shape_range
    return self.randInt(low=shapeRange[0], high=shapeRange[1])
317
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for one expected output tensor.

    Returns None for error tests and for types that
    ``gtu.dtypeIsSupportedByCompliance`` rejects; otherwise a dict with
    "mode", "data_type" and, depending on the mode, one extra "*_info"
    entry.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Create compliance meta data for expected output tensor
    tensorInfo = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    # Per-op compliance settings; empty when the op declares none
    opCompliance = op["compliance"] if "compliance" in op else {}

    if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        # "ksb" takes priority over "ks" when both are present
        ksKey = "ksb" if "ksb" in argsDict else "ks"
        tensorInfo["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ksKey]),
        }
    elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in opCompliance:
        mode = gtu.ComplianceMode.ULP
        tensorInfo["ulp_info"] = {"ulp": opCompliance["ulp"]}
    elif "relative" in opCompliance:
        mode = gtu.ComplianceMode.RELATIVE
        tensorInfo["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": opCompliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        tensorInfo["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in opCompliance:
            tensorInfo["abs_error_info"] = {
                "lower_bound": opCompliance["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    tensorInfo["mode"] = gtu.ComplianceMode(mode).name
    return tensorInfo
380
381 # Build Op functions
382 # Create the output tensor (calling OutputShaper as needed)
383 # Do final tweaks to attributes (if necessary for errorIf)
384 # Add Op into graph
385 # Return resulting tensor information or BuildInfo
386
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        """Store result tensor(s) and compliance dict(s) as parallel lists.

        A single tensor is wrapped in a one element list, as is its
        compliance dict unless that is None. A list of tensors is stored
        as given, with a compliance value that is None or a list.
        """
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return

        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return the compliance info keyed by tensor name, or None if empty."""
        if self.complianceDictList is None:
            return None

        tensorInfo = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not tensorInfo:
            return None

        # Have some compliance data, so return the info
        return {
            "version": "0.1",
            "tensors": tensorInfo,
        }
Eric Kunzee5e26762020-10-13 16:11:07 -0700420
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph and return its BuildInfo.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports
    failure. For NEGATE the two ``qinfo`` entries are passed to the negate
    attribute.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, operand, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and result_tensor.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, operand.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [operand.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=operand.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700473
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two-input broadcasting operator and return its BuildInfo.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports
    failure. ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700515
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input non-broadcasting operator and return its result tensor.

    ``validator_fcns`` and ``error_name`` are accepted but not used here.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    inputNames = [a.name, b.name]
    outputNames = [output.name]
    self.ser.addOperator(op["op"], inputNames, outputNames)
    return output
520
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator and return its result tensor.

    ``round`` is passed to the operator's attribute. Returns None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    output = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [output.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    shiftAttr = ts.TosaSerializerAttribute()
    shiftAttr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, shiftAttr)
    return output
558
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator to the graph and return its BuildInfo.

    The shift attribute comes from ``args_dict["shift"]``. Returns None
    when ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    isFloatInput = lhs.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not isFloatInput:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Deliberately pick an output type other than INT32
        wrongType = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tensor.setDtype(wrongType)

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    mulAttr = ts.TosaSerializerAttribute()
    mulAttr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, mulAttr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700614
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup operator and return its result tensor.

    ``table`` is passed to the operator's attribute. Returns None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    output = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    tableAttr = ts.TosaSerializerAttribute()
    tableAttr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, tableAttr)
    return output
648
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a select operator (condition plus two inputs) and return its BuildInfo.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports
    failure. ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 3
    cond, onTrue, onFalse = inputs

    result_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, onTrue, onFalse, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholderCount, constCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [cond.name, onTrue.name, onFalse.name],
        [result_tensor.name],
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=onTrue,
        input3=onFalse,
        input_shape=onTrue.shape,
        input_dtype=onTrue.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholderCount + constCount,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, onTrue.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700696
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a binary comparison operator and return its build info.

    ``inputs`` must hold exactly two tensors.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports a
    failure, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    p_count, c_count = op["operands"]

    # The operand name lists may be altered here for error-if testing.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validator_kwargs = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700744
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Serialize an argmax operator and return its build information.

    ``inputs`` must hold exactly one tensor and ``args_dict`` must provide
    the ``"axis"`` entry, which is stored in the operator's axis attribute.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports a
    failure, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]

    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]

    # The operand name lists may be altered here for error-if testing.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    validator_kwargs = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700788
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator and return its build information.

    ``inputs`` must hold exactly one tensor.  ``args_dict`` supplies the
    ``"stride"``, ``"pad"`` and ``"kernel"`` entries and, optionally,
    ``"acc_type"``.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports a
    failure, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    # max_pool has no accum_dtype, so fall back to UNKNOWN when absent
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, in_tensor, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and in_tensor.dtype not in (DType.INT8, DType.UINT8):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, in_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    p_count, c_count = op["operands"]

    # The operand name lists may be altered here for error-if testing.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [result_tensor.name]
    )

    validator_kwargs = {
        "op": op,
        "input_shape": in_tensor.shape,
        "input_dtype": in_tensor.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "accum_dtype": accum_dtype,
        "kernel": kernel,
        "stride": stride,
        "pad": pad,
        "qinfo": qinfo,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )
    if not checks_ok:
        return None

    # The validators see the caller's qinfo; only the attribute gets defaults.
    if qinfo is None:
        qinfo = [0, 0]

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100863
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution operator and return its build information.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must provide ``"acc_type"``, ``"stride"``, ``"pad"``
            (four entries) and ``"dilation"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the convolution attribute.  When
            None, ``[0, 0]`` is used for the attribute (the same default as
            ``build_pool2d``) instead of failing on ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700945
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator and return its build information.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must provide ``"acc_type"``, ``"stride"``, ``"pad"``
            (six entries) and ``"dilation"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the convolution attribute.  When
            None, ``[0, 0]`` is used for the attribute (the same default as
            ``build_pool2d``) instead of failing on ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001027
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a transpose 2D convolution operator and return its build info.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must provide ``"acc_type"``, ``"stride"``, ``"pad"``
            (four entries, used as the output padding) and ``"out_shape"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the transpose convolution
            attribute.  When None, ``[0, 0]`` is used for the attribute (the
            same default as ``build_pool2d``) instead of failing on
            ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001102
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise 2D convolution operator and return its build info.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must provide ``"acc_type"``, ``"stride"``, ``"pad"`` and
            ``"dilation"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the convolution attribute.  When
            None, ``[0, 0]`` is used for the attribute (the same default as
            ``build_pool2d``) instead of failing on ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001183
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully connected operator and return its build information.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must provide ``"acc_type"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the fully connected attribute
            (presumably the zero points - confirm against the quantization
            generator).  When None, ``[0, 0]`` is used for the attribute (the
            same default as ``build_pool2d``) instead of failing on
            ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001239
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a matmul operator and return its build information.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly two tensors.
        args_dict: Must provide ``"acc_type"``.
        validator_fcns: Validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF being tested, or None.
        qinfo: Two-entry sequence passed to the matmul attribute (presumably
            the zero points - confirm against the quantization generator).
            When None, ``[0, 0]`` is used for the attribute (the same default
            as ``build_pool2d``) instead of failing on ``qinfo[0]``.

    Returns:
        None when the ERROR_IF validation reports a failure, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete values.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001289
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize a reduction operator and return its build information.

    ``inputs`` must hold exactly one tensor and ``args_dict`` must provide
    the ``"axis"`` entry, which is stored in the operator's axis attribute.
    For a REDUCE_PRODUCT without an error test, ``args_dict["n"]`` is set
    to the input's size along that axis before the compliance data is built.

    Returns None when ``TosaErrorValidator.evValidateErrorIfs`` reports a
    failure, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]

    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    p_count, c_count = op["operands"]

    # The operand name lists may be altered here for error-if testing.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    validator_kwargs = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001338
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a clamp operator with randomly drawn min/max limits to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List holding exactly one input tensor.
        args_dict: Passed through to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    limits = [self.getRandNumberDType(src.dtype), self.getRandNumberDType(src.dtype)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Redraw until the two numbers differ, otherwise swapping them would
        # not make max smaller than min.
        while limits[0] == limits[1]:
            limits = [
                self.getRandNumberDType(src.dtype),
                self.getRandNumberDType(src.dtype),
            ]
        max_val, min_val = min(limits), max(limits)
    else:
        max_val, min_val = max(limits), min(limits)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "max_val": max_val,
        "min_val": min_val,
        "input_shape": src.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": src.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val, max_val = (
                min_val.astype(np.float32),
                max_val.astype(np.float32),
            )
        # Floating point limits go in the last two attribute slots.
        clamp_args = (0, 0, min_val, max_val)
    else:
        # Other dtypes use the first two attribute slots.
        clamp_args = (min_val, max_val, 0, 0)
    attr.ClampAttribute(self.ser.builder, *clamp_args)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001404
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-relu operator carrying a random FP32 attribute value.

    ``validator_fcns`` is accepted but not used here. Returns the result
    tensor produced by ``OutputShaper.unaryOp``.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    rand_fp32 = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(rand_fp32)

    in_names = [a.name]
    out_names = [output.name]
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return output
1413
1414 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a prelu operator with the single input ``a`` and no attribute.

    ``validator_fcns`` is accepted but not used here. Returns the result
    tensor produced by ``OutputShaper.unaryOp``.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [output.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return output
1420
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input activation operator (no attribute) to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List holding exactly one input tensor.
        args_dict: Passed through to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, operand, error_name)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [operand.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001461
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT or CONCAT_SHAPE operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` must be ``Op.CONCAT`` or
            ``Op.CONCAT_SHAPE`` and ``op["operands"]`` is read.
        inputs: Tensors to concatenate; the first one supplies the
            shape/dtype reported to the validators and compliance data.
        args_dict: Must contain ``"axis"`` for ``Op.CONCAT``; also passed
            through to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    # CONCAT_SHAPE has no axis argument; axis 0 is used for it.
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance() is the idiomatic type check (PEP 8), replacing the
        # exact-type comparison.
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    # Only CONCAT carries an axis attribute; CONCAT_SHAPE is serialized bare.
    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001520
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator whose padding is supplied as a second input tensor.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly two tensors: the data tensor and the
            padding tensor.
        args_dict: Must contain ``"pad"``, ``"pad_const_int"`` and
            ``"pad_const_fp"``; also passed to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Forwarded to the validators.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 2
    data, pad_input = inputs
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(
        self.ser, self.rng, data, padding, error_name
    )

    # write empty padding into PadAttribute to ensure inputs[1] is used
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, pad_input.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input_shape": data.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": data.dtype,
        "output_dtype": result_tensor.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": data,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001578
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a dim operator with an axis attribute to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List holding exactly one input tensor.
        args_dict: Must contain ``"axis"``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance data, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    source = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, self.rng, source, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "axis": axis,
        "input_shape": source.shape,
        "input_dtype": source.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    # No compliance meta data is generated for this operator.
    return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001624
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a reshape operator whose new shape is supplied as a second input.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly two tensors: the data tensor and the shape
            tensor.
        args_dict: Must contain ``"new_shape"`` (used to compute the output
            tensor); also passed to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 2
    data, shape_tensor = inputs
    new_shape = args_dict["new_shape"]
    result_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, data, new_shape, error_name
    )

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, shape_tensor.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input_shape": data.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": data.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001668
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a reverse operator with an axis attribute to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List holding exactly one input tensor.
        args_dict: Must contain ``"axis"``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance data, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    source = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, source, error_name)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "axis": axis,
        "input_shape": source.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": source.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    # No compliance meta data is generated for this operator.
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001708
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a transpose operator with a permutation attribute to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List holding exactly one input tensor.
        args_dict: Must contain ``"perms"``; also passed to
            ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    source = inputs[0]
    perms = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, source, perms, error_name
    )

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input_shape": source.shape,
        "output_shape": result_tensor.shape,
        "perms": perms,
        "input_dtype": source.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": source,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001757
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a slice operator taking start/size both as inputs and as attribute.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly three tensors: data, start and size.
        args_dict: Must contain ``"start"`` and ``"size"``; also passed to
            ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 3
    data, start_var, size_var = inputs
    start_const = args_dict["start"]
    size_const = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, data, start_const, size_const, error_name
    )

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [data.name, start_var.name, size_var.name],
        [result_tensor.name],
    )

    validator_args = {
        "op": op,
        "input_shape": data.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": data.dtype,
        "output_dtype": result_tensor.dtype,
        "start": start_const,
        "size": size_const,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": data,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    # TODO remove the slice attribute once shape dynamism support is mature.
    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start_const, size_const)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001809
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a tile operator whose multiples are supplied as a second input.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly two tensors: the data tensor and the
            multiples tensor.
        args_dict: Must contain ``"multiples"`` (used to compute the output
            tensor); also passed to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 2
    data, multiples = inputs
    multiples_attr = args_dict["multiples"]
    result_tensor = OutputShaper.tileOp(
        self.ser, self.rng, data, multiples_attr, error_name
    )

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, multiples.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input_shape": data.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": data.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": data,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001854
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a gather operator (values + indices inputs) to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly two tensors: values and indices.
        args_dict: Passed through to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 2
    values, indices = inputs

    result_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tensor.name]
    )

    # Only the values tensor's shape/dtype are reported to the validators.
    validator_args = {
        "op": op,
        "input_shape": values.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": values.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001897
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a scatter operator (three tensor inputs) to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: List of exactly three tensors, unpacked in order as
            values_in, indices and the input tensor.
        args_dict: Passed through to ``tensorComplianceMetaData``.
        validator_fcns: Forwarded to ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case name (None by default).
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when
        ``evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [result_tensor.name],
    )

    # Only the values_in tensor's shape/dtype are reported to the validators.
    validator_args = {
        "op": op,
        "input_shape": values_in.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": values_in.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001939
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Add an operator carrying a resize attribute to the serializer.

    The single entry of ``inputs`` is the input tensor. ``args_dict`` supplies
    the "mode", "scale", "offset", "border" and "output_dtype" arguments.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta data, or ``None`` when the ERROR_IF validation fails.
    ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    mode, scale, offset, border, output_dtype = (
        args_dict[key]
        for key in ("mode", "scale", "offset", "border", "output_dtype")
    )

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        in_tensor,
        mode,
        scale,
        offset,
        border,
        in_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [result_tensor.name]
    )

    validation_args = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": in_tensor.dtype,
        "output_dtype": output_dtype,
        "input_shape": in_tensor.shape,
        "output_shape": result_tensor.shape,
        "offset": offset,
        "border": border,
        "input_list": input_list,
        "output_list": output_list,
        "result_tensors": [result_tensor],
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002008
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add a two-input, two-output operator to the serializer.

    One output tensor is created per input via ``OutputShaper.unaryOp``.
    Only the output created for ``val`` is returned; ``validator_fcns`` is
    accepted but not used here.
    """
    outputs = [
        OutputShaper.unaryOp(self.ser, self.rng, tensor, error_name)
        for tensor in (val, val2)
    ]
    input_names = [val.name, val2.name]
    output_names = [tensor.name for tensor in outputs]
    self.ser.addOperator(op, input_names, output_names)
    return outputs[0]
2016
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Register the single input tensor as an output of the serializer.

    Returns a ``TosaTestGen.BuildInfo`` pairing that tensor with its
    compliance meta data. ``validator_fcns`` and ``qinfo`` are accepted but
    not used here.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]
    self.ser.addOutputTensor(const_tensor)

    meta = self.tensorComplianceMetaData(
        op, const_tensor.dtype, args_dict, const_tensor, error_name
    )
    return TosaTestGen.BuildInfo(const_tensor, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07002029
2030 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a type conversion operator for the single input tensor.

    The target data type is read from ``args_dict["out_type"]``.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta data, or ``None`` when the ERROR_IF validation fails.
    ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    validation_args = {
        "op": op,
        "input_shape": src.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": src.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002074
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add an operator carrying a rescale attribute to the serializer.

    Steps performed:
      1. Create the result tensor for the conversion of ``val`` to
         ``out_dtype``.
      2. Pick input/output zero points from the data types (or deliberately
         non-zero ones for the zero point ERROR_IF cases).
      3. Generate one random scale per channel (last dimension of ``val``)
         when ``per_channel`` is set, otherwise a single scale, and convert
         each to a multiplier/shift pair.
      4. For positive ``scale32`` tests, clamp the placeholder values stored
         on disk so that (value - input_zp) stays inside the range allowed by
         the computed shift, rewriting the file if anything changed.
      5. Run the ERROR_IF validation and serialize the operator.

    Returns the result tensor, or ``None`` when the ERROR_IF validation
    fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # The error case needs a zero point that is guaranteed non-zero
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # The error case needs a zero point that is guaranteed non-zero
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    # Cap the scaling at 2^31 - 1 for scale32 and at 2^15 - 1 for scale16
    max_scale = (1 << 31) - 1 if scale32 else 32767.0
    scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), max_scale)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)

        # Check every element can safely be converted back to the stored
        # dtype. The comparison must be element-wise; reducing with .all()
        # before comparing would only compare a single boolean to the limits.
        dtype_limits = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_limits.min) and np.all(
            val_adj <= dtype_limits.max
        ), "Adjusted rescale values do not fit the placeholder data type"

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2247
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a conditional operator.

    The tensor is a rank 0 BOOL constant holding ``cond`` unless
    ``error_name`` requests a wrong data type
    (``ErrorIf.CondIfCondNotMatchingBool``) or a wrong shape
    (``ErrorIf.CondIfCondShapeNotSizeOne``).
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2264
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Add a conditional operator whose THEN/ELSE blocks only hold constants.

    The supplied then/else tensors are ignored except for the shape of
    ``then_tens``; each block is filled with a freshly generated INT32
    constant. For the output list mismatch ERROR_IF cases the affected block
    gets a constant with a deliberately different shape.

    Returns the result tensor, or ``None`` when the ERROR_IF validation
    fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            # Small dimensions only ever grow
            deltas = [-3, -2, 2, 3] if dim > 3 else [1, 2, 4]
            incorrect_shape[idx] = dim + self.rng.choice(deltas)
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Build the actual then/else tensors inside their blocks
    self.ser.addBasicBlock(then_block)
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_shape, then_data = incorrect_shape, incorrect_arr
    else:
        then_shape, then_data = out_shape, then_arr
    self.ser.addOutputTensor(self.ser.addConst(then_shape, DType.INT32, then_data))

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_shape, else_data = incorrect_shape, incorrect_arr
    else:
        else_shape, else_data = out_shape, else_arr
    self.ser.addOutputTensor(self.ser.addConst(else_shape, DType.INT32, else_data))

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2333
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Add a conditional operator with a binary op in each of its blocks.

    Both blocks take ``a`` and ``b`` as inputs. The THEN/ELSE blocks use
    ADD/SUB for FP32, BF16, FP16 and INT32 inputs and logical right/left
    shift for INT8 and INT16 inputs. For the input/output list mismatch
    ERROR_IF cases the affected block is given a tensor with a deliberately
    different shape.

    Returns the result tensor, or ``None`` when the ERROR_IF validation
    fails.

    Raises:
        AssertionError: if the data type of ``a`` has no block operators.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const there is no guard for
        # small dimensions here, so a dimension may end up <= 0 - confirm
        # whether that is intended.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Raised explicitly so the check survives running with -O
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be called `op`: the operator description
    # passed in as `op` is still needed for the validation call below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2416
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Add a while loop operator with its COND and BODY blocks.

    The loop carries three tensors: an iteration counter initialised to
    ``iter_val``, the tensor ``a`` and an accumulator initialised to zeros.
    The COND block tests ``counter > 0``; the BODY block adds ``a`` to the
    accumulator and subtracts one from the counter. The various ERROR_IF
    cases swap in tensors with deliberately wrong shapes or a wrong
    condition type/shape.

    Returns the accumulator output tensor, or ``None`` when the ERROR_IF
    validation fails.
    """
    shape_deltas = [-3, -2, 2, 3]

    counter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for dim in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[dim] += self.rng.choice(shape_deltas)
        acc_out_shape = incorrect_acc.shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [counter.name, a.name, acc.name],
        [counter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    )
    if error_name in graph_mismatch_errors:
        incorrect_counter = deepcopy(counter)
        for dim in range(len(incorrect_counter.shape)):
            incorrect_counter.shape[dim] += self.rng.choice(shape_deltas)
        if len(incorrect_counter.shape) == 0:
            incorrect_counter.shape.append(self.rng.choice(shape_deltas))

        incorrect_acc = deepcopy(acc)
        for dim in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[dim] += self.rng.choice(shape_deltas)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_counter, a, incorrect_acc)
    else:
        cond_inputs = (counter, a, acc)
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [counter.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_counter, a, incorrect_acc)
    else:
        body_inputs = (counter, a, acc)
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        counter_template, acc_template = incorrect_counter, incorrect_acc
    else:
        counter_template, acc_template = counter, acc
    counter_body_out = self.ser.addIntermediate(
        counter_template.shape, counter_template.dtype
    )
    acc_body_out = self.ser.addIntermediate(
        acc_template.shape, acc_template.dtype
    )

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [counter.name, one_tens.name], [counter_body_out.name]
    )
    for tensor in (counter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if is_valid else None
2537
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an operator carrying an FFT attribute to the serializer.

    ``inputs`` holds the two input tensors and ``args_dict["inverse"]``
    supplies the inverse flag. The local bound attribute is always False.

    Returns a ``TosaTestGen.BuildInfo`` holding the list of result tensors
    and one compliance meta data entry per result, or ``None`` when the
    ERROR_IF validation fails. ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 2
    val1, val2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    output_shapes = [tensor.shape for tensor in results]
    output_dtypes = [tensor.dtype for tensor in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tensor.name for tensor in results],
    )

    validation_args = {
        "op": op,
        "inverse": inverse,
        "input1": val1,
        "input2": val2,
        "input_shape": val1.shape,
        "input_dtype": val1.dtype,
        "output_shape": output_shapes,
        "output_dtype": output_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val1.dtype, args_dict, tensor, error_name)
        for tensor in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002601
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an operator carrying an RFFT attribute to the serializer.

    The single entry of ``inputs`` is the input tensor. The local bound
    attribute is always False.

    Returns a ``TosaTestGen.BuildInfo`` holding the list of result tensors
    and one compliance meta data entry per result, or ``None`` when the
    ERROR_IF validation fails. ``qinfo`` is accepted but not used here.
    """
    assert len(inputs) == 1
    src = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, self.rng, src, error_name)

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    output_shapes = [tensor.shape for tensor in results]
    output_dtypes = [tensor.dtype for tensor in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [tensor.name for tensor in results]
    )

    validation_args = {
        "op": op,
        "input_shape": src.shape,
        "input_dtype": src.dtype,
        "output_shape": output_shapes,
        "output_dtype": output_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, src.dtype, args_dict, tensor, error_name)
        for tensor in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002658
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input operator with no attribute to the serializer.

    Args:
        op: operator entry; `op["op"]` and `op["operands"]` are read here.
        inputs: list holding exactly two input tensors.
        args_dict: forwarded unchanged to tensorComplianceMetaData().
        validator_fcns: ERROR_IF validators handed to evValidateErrorIfs().
        error_name: name of the ERROR_IF under test, or None.
        qinfo: accepted for interface compatibility; not used here.

    Returns:
        A TosaTestGen.BuildInfo for the single result tensor, or None when
        the ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_shape": lhs.shape,
        "input_dtype": lhs.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2704
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Args:
        op: operator entry; `op["rank"]` (min, max) and `op["types"]` are read.
        shapeFilter: list of requested shapes; an empty/None value becomes [None].
        rankFilter: requested ranks, or None for the default selection.
        dtypeFilter: requested dtypes, or None to keep all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; called as validator(check=False, op=op)
            for negative tests to obtain its "param_reqs".

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator (and for any other
        testType value).
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    fallback_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Work out the ranks from what is requested and what the operator allows
    lowest, highest = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator supports
        ranks = [rank for rank in rankFilter if lowest <= rank <= highest]
    elif shapeFilter[0] is None:
        # Nothing requested: use the operator range or the default range,
        # whichever covers fewer ranks.
        op_ranks = range(lowest, highest + 1)
        ranks = op_ranks if len(op_ranks) <= len(fallback_ranks) else fallback_ranks
    else:
        ranks = range(lowest, highest + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # Operator dtypes narrowed to the requested ones; a list-valued dtype
        # is matched on its first element.
        dtypes = [
            candidate
            for candidate in op_dtypes
            if candidate in dtypeFilter
            or (isinstance(candidate, list) and candidate[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    requirements = validator(check=False, op=op)["param_reqs"]

    # Any parameter the validator pins overrides the cleaned filter
    required_rank = requirements["rank"]
    negative_ranks = ranks if required_rank is None else required_rank

    required_dtype = requirements["dtype"]
    negative_dtypes = dtypes if required_dtype is None else required_dtype

    required_shape = requirements["shape"]
    if required_shape is None:
        # Reduce number of shapes to keep test numbers small
        negative_shapes = shapeFilter[:2]
    else:
        negative_shapes = required_shape

    return {
        "shapeFilter": negative_shapes,
        "rankFilter": negative_ranks,
        "dtypeFilter": negative_dtypes,
    }
2785
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of tests to generate for one operator.

    Args:
        opName: key into self.TOSA_OP_LIST.
        shapeFilter: list of requested shapes; None (the default) means
            "no shape requested" and is treated as [None].
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".

    Returns:
        A list of tuples
        (opName, testNameStr, dtype, error_name, shapeList, argumentsList).
        An empty list is returned when create_filter_lists() yields no filters.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.

    Side effects:
        Re-seeds self.rng from self.random_seed and calls
        self.setTargetShape() for every shape considered.
    """
    # A None sentinel replaces the former mutable `[None]` default so the
    # default object can never be shared (and mutated) across calls.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # A single pass with no validator
        error_if_validators = [None]

    # Test list consists of tuples of:
    #   (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        # Negative test names carry the ERROR_IF name after the op name
        if testType == "negative":
            namePrefix = f"{opName}_ERRORIF_{error_name}"
        else:
            namePrefix = opName

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list holds (name string, build function
                    # arguments) pairs; one unnamed entry when the op has
                    # no argument generator.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{namePrefix}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr = f"{testStr}_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. Every validator is still called for every test.
        invalid_test_validators = op["invalid_test_validators"]
        clean_testList = []
        for test in testList:
            verdicts = [
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            ]
            if not any(verdicts):
                clean_testList.append(test)
        testList = clean_testList

    return testList
2892
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Generate the tensors for one test, build its operator and serialize it.

    Args:
        opName: key into self.TOSA_OP_LIST.
        testStr: test name, passed to createSerializer().
        dtype_or_dtypeList: a single dtype or a list of per-operand dtypes.
        error_name: name of the ERROR_IF under test, or None.
        shapeList: list of operand shapes.
        testArgs: a dict (new interface, must contain "dg_type") or a list
            of positional build arguments (old interface).

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
        TypeError: re-raised from the old-interface build call after the
            call details have been printed.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # These ops take one dtype per supplied shape instead of one per operand
    per_shape_dtypes = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if per_shape_dtypes else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not per_shape_dtypes:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is generated before the tensor values
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        # Without validators qinfo (if any) is passed positionally; with
        # validators everything extra is passed by keyword.
        trailing_args = []
        keyword_args = {}
        if error_if_validators is None:
            if qinfo is not None:
                trailing_args.append(qinfo)
        else:
            keyword_args["validator_fcns"] = error_if_validators
            keyword_args["error_name"] = error_name
            if qinfo is not None:
                keyword_args["qinfo"] = qinfo

        try:
            result = build_fcn(
                self, op, *tens, *testArgs, *trailing_args, **keyword_args
            )
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        # Add the compliance meta data (if any)
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003016
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE convolution entries into one op per kernel size.

    For every 2D kernel a conv2d, depthwise_conv2d and transpose_conv2d
    entry is created (in that order), and for every 3D kernel a conv3d
    entry. Each new entry is a shallow copy of its template with "filter"
    set to the kernel and "template" set to False. All entries still marked
    as templates are then removed from self.TOSA_OP_LIST.

    Does nothing when "conv2d_TEMPLATE" is no longer present.
    """
    op_table = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in op_table:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, bigK], [bigK, 2]]
        kernels_3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    for kernel in kernels_2d:
        for family in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            entry = op_table[family + "_TEMPLATE"].copy()
            op_table["{}_{}x{}".format(family, kernel[0], kernel[1])] = entry
            entry["filter"] = kernel
            entry["template"] = False

    for kernel in kernels_3d:
        entry = op_table["conv3d_TEMPLATE"].copy()
        op_table["conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2])] = entry
        entry["filter"] = kernel
        entry["template"] = False

    # Delete any templates after having created any dynamic ops.
    # Collect the names first: keys must not be deleted while iterating.
    template_names = [
        name for name, entry in op_table.items() if entry.get("template", False)
    ]
    for name in template_names:
        del op_table[name]
3072
def initOpListDefaults(self):
    """Check the required fields of every op and fill in optional defaults.

    Required per entry of self.TOSA_OP_LIST: "operands" (unpacks into two
    values), "build_fcn" (unpacks into four values), "types" and "op".
    A missing "rank" is set to self.DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the op and the first required field found to be
            missing or malformed.
    """
    # Fields that only have to be present, with their complaint templates
    presence_checks = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for field, complaint in presence_checks:
            try:
                entry[field]
            except KeyError:
                raise Exception(complaint.format(name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003114
# Tensor operator list
#  'op': Op enum value of the operator
#  'operands': tuple of (placeholder, const) operand counts
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
#  'types': array of datatypes to be tested
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (excludes INT4)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]

# Integer types followed by floating-point types (excludes INT4)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]

# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]

# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL

TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to ops that do not specify their own "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003166
3167 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003169 "argmax": {
3170 "op": Op.ARGMAX,
3171 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003172 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 "build_fcn": (
3174 build_argmax,
3175 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003176 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 TosaArgGen.agAxis,
3178 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003179 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003180 "error_if_validators": (
3181 TosaErrorValidator.evAxisSmallerZero,
3182 TosaErrorValidator.evAxisLargerRank,
3183 TosaErrorValidator.evArgmaxOutputRankMismatch,
3184 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3185 TosaErrorValidator.evWrongRank,
3186 TosaErrorValidator.evWrongInputType,
3187 TosaErrorValidator.evWrongOutputType,
3188 TosaErrorValidator.evWrongInputList,
3189 TosaErrorValidator.evWrongOutputList,
3190 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003191 "data_gen": {
3192 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3193 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 "avg_pool2d": {
3196 "op": Op.AVG_POOL2D,
3197 "operands": (1, 0),
3198 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003199 "build_fcn": (
3200 build_pool2d,
3201 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003202 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003203 TosaArgGen.agPooling,
3204 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 "qgen": TosaQuantGen.qgUnary,
3206 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003207 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 "error_if_validators": (
3209 TosaErrorValidator.evKernelSmallerOne,
3210 TosaErrorValidator.evStrideSmallerOne,
3211 TosaErrorValidator.evPadSmallerZero,
3212 TosaErrorValidator.evWrongRank,
3213 TosaErrorValidator.evWrongInputType,
3214 TosaErrorValidator.evWrongOutputType,
3215 TosaErrorValidator.evWrongInputList,
3216 TosaErrorValidator.evWrongOutputList,
3217 TosaErrorValidator.evInputZeroPointNotZero,
3218 TosaErrorValidator.evOutputZeroPointNotZero,
3219 TosaErrorValidator.evPadLargerEqualKernel,
3220 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003221 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003222 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003223 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003224 "data_gen": {
3225 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3226 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003228 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003229 "conv2d_TEMPLATE": {
3230 "op": Op.CONV2D,
3231 "operands": (1, 2),
3232 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 "build_fcn": (
3234 build_conv2d,
3235 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003236 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 TosaArgGen.agConv,
3238 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003239 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003240 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003241 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3242 "error_if_validators": (
3243 TosaErrorValidator.evWrongInputType,
3244 TosaErrorValidator.evWrongOutputType,
3245 TosaErrorValidator.evWrongInputList,
3246 TosaErrorValidator.evWrongOutputList,
3247 TosaErrorValidator.evInputZeroPointNotZero,
3248 TosaErrorValidator.evWeightZeroPointNotZero,
3249 TosaErrorValidator.evPadSmallerZero,
3250 TosaErrorValidator.evStrideSmallerOne,
3251 TosaErrorValidator.evDilationSmallerOne,
3252 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003253 TosaErrorValidator.evConvOutputShapeMismatch,
3254 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003255 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003256 "data_gen": {
3257 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3258 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003259 "template": True,
3260 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003261 # Templated operator. Filled in by createDynamicOpLists
3262 "conv3d_TEMPLATE": {
3263 "op": Op.CONV3D,
3264 "operands": (1, 2),
3265 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 "build_fcn": (
3267 build_conv3d,
3268 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003269 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003270 TosaArgGen.agConv,
3271 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003272 "qgen": TosaQuantGen.qgConv,
3273 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003274 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3275 "error_if_validators": (
3276 TosaErrorValidator.evWrongInputType,
3277 TosaErrorValidator.evWrongOutputType,
3278 TosaErrorValidator.evWrongInputList,
3279 TosaErrorValidator.evWrongOutputList,
3280 TosaErrorValidator.evInputZeroPointNotZero,
3281 TosaErrorValidator.evWeightZeroPointNotZero,
3282 TosaErrorValidator.evPadSmallerZero,
3283 TosaErrorValidator.evStrideSmallerOne,
3284 TosaErrorValidator.evDilationSmallerOne,
3285 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003286 TosaErrorValidator.evConvOutputShapeMismatch,
3287 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003288 ),
evacha0147ab1762024-01-29 13:23:23 +00003289 "data_gen": {
3290 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3291 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003292 "template": True,
3293 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003294 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003295 "depthwise_conv2d_TEMPLATE": {
3296 "op": Op.DEPTHWISE_CONV2D,
3297 "operands": (1, 2),
3298 "filter": [1, 1],
3299 "rank": (4, 4),
3300 "build_fcn": (
3301 build_depthwise_conv2d,
3302 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003303 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003304 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003305 ),
3306 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003307 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003308 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3309 "error_if_validators": (
3310 TosaErrorValidator.evWrongInputType,
3311 TosaErrorValidator.evWrongOutputType,
3312 TosaErrorValidator.evWrongInputList,
3313 TosaErrorValidator.evWrongOutputList,
3314 TosaErrorValidator.evInputZeroPointNotZero,
3315 TosaErrorValidator.evWeightZeroPointNotZero,
3316 TosaErrorValidator.evPadSmallerZero,
3317 TosaErrorValidator.evStrideSmallerOne,
3318 TosaErrorValidator.evDilationSmallerOne,
3319 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003320 TosaErrorValidator.evConvOutputShapeMismatch,
3321 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003322 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003323 "data_gen": {
3324 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3325 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003326 "template": True,
3327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 "fully_connected": {
3329 "op": Op.FULLY_CONNECTED,
3330 "operands": (1, 2),
3331 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 "build_fcn": (
3333 build_fully_connected,
3334 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003335 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003336 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003339 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003340 "error_if_validators": (
3341 TosaErrorValidator.evInputZeroPointNotZero,
3342 TosaErrorValidator.evWeightZeroPointNotZero,
3343 TosaErrorValidator.evWrongRank,
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongInputList,
3347 TosaErrorValidator.evWrongOutputList,
3348 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003349 "data_gen": {
3350 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3351 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "matmul": {
3354 "op": Op.MATMUL,
3355 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003356 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003357 "build_fcn": (
3358 build_matmul,
3359 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003360 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003361 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 "qgen": TosaQuantGen.qgMatmul,
3364 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evInputZeroPointNotZero,
3367 TosaErrorValidator.evWrongRank,
3368 TosaErrorValidator.evWrongInputType,
3369 TosaErrorValidator.evWrongOutputType,
3370 TosaErrorValidator.evWrongInputList,
3371 TosaErrorValidator.evWrongOutputList,
3372 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003373 "data_gen": {
3374 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "max_pool2d": {
3378 "op": Op.MAX_POOL2D,
3379 "operands": (1, 0),
3380 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003382 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003383 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003384 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003385 TosaArgGen.agPooling,
3386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003388 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003389 "error_if_validators": (
3390 TosaErrorValidator.evKernelSmallerOne,
3391 TosaErrorValidator.evStrideSmallerOne,
3392 TosaErrorValidator.evPadSmallerZero,
3393 TosaErrorValidator.evWrongRank,
3394 TosaErrorValidator.evWrongInputType,
3395 TosaErrorValidator.evWrongOutputType,
3396 TosaErrorValidator.evWrongInputList,
3397 TosaErrorValidator.evWrongOutputList,
3398 TosaErrorValidator.evPadLargerEqualKernel,
3399 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003400 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003401 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003402 "data_gen": {
3403 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3404 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003406 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003407 "transpose_conv2d_TEMPLATE": {
3408 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003409 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003410 "rank": (4, 4),
3411 "build_fcn": (
3412 build_transpose_conv2d,
3413 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003414 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003415 TosaArgGen.agTransposeConv2D,
3416 ),
3417 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003418 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003419 "invalid_test_validators": (
3420 TosaInvalidValidator.ivHeightWidthInvalid,
3421 TosaInvalidValidator.ivNonPositiveOutputShape,
3422 ),
3423 "error_if_validators": (
3424 TosaErrorValidator.evWrongInputType,
3425 TosaErrorValidator.evWrongOutputType,
3426 TosaErrorValidator.evWrongInputList,
3427 TosaErrorValidator.evWrongOutputList,
3428 TosaErrorValidator.evInputZeroPointNotZero,
3429 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003430 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003431 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003432 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003433 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003434 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003435 "data_gen": {
3436 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3437 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003438 "template": True,
3439 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003440 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003441 "clamp": {
3442 "op": Op.CLAMP,
3443 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 "build_fcn": (
3445 build_clamp,
3446 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003447 TosaTensorValuesGen.tvgLazyGenDefault,
3448 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003450 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 "error_if_validators": (
3452 TosaErrorValidator.evMaxSmallerMin,
3453 TosaErrorValidator.evWrongInputType,
3454 TosaErrorValidator.evWrongOutputType,
3455 TosaErrorValidator.evWrongInputList,
3456 TosaErrorValidator.evWrongOutputList,
3457 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003458 "data_gen": {
3459 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3460 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003461 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003462 "sigmoid": {
3463 "op": Op.SIGMOID,
3464 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003465 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003466 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003468 TosaTensorValuesGen.tvgLazyGenDefault,
3469 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003470 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003471 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003472 "error_if_validators": (
3473 TosaErrorValidator.evWrongInputType,
3474 TosaErrorValidator.evWrongOutputType,
3475 TosaErrorValidator.evWrongInputList,
3476 TosaErrorValidator.evWrongOutputList,
3477 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003478 "data_gen": {
3479 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3480 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003481 },
3482 "tanh": {
3483 "op": Op.TANH,
3484 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003485 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003486 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003488 TosaTensorValuesGen.tvgLazyGenDefault,
3489 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003490 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003491 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003492 "error_if_validators": (
3493 TosaErrorValidator.evWrongInputType,
3494 TosaErrorValidator.evWrongOutputType,
3495 TosaErrorValidator.evWrongInputList,
3496 TosaErrorValidator.evWrongOutputList,
3497 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003498 "data_gen": {
3499 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3500 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003501 "compliance": {
3502 "abs_error_lower_bound": 0.5,
3503 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003504 },
Won Jeon78155c62023-06-10 00:20:04 +00003505 "erf": {
3506 "op": Op.ERF,
3507 "operands": (1, 0),
3508 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003509 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003510 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003511 TosaTensorValuesGen.tvgLazyGenDefault,
3512 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003513 ),
3514 "types": TYPE_FP,
3515 "error_if_validators": (
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongInputList,
3519 TosaErrorValidator.evWrongOutputList,
3520 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003521 "data_gen": {
3522 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3523 },
3524 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 # Elementwise Binary Operators
3527 "add": {
3528 "op": Op.ADD,
3529 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 "build_fcn": (
3531 build_binary_broadcast,
3532 TosaTensorGen.tgBroadcastFuzz,
3533 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003534 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003537 "error_if_validators": (
3538 TosaErrorValidator.evRankMismatch,
3539 TosaErrorValidator.evWrongInputType,
3540 TosaErrorValidator.evWrongOutputType,
3541 TosaErrorValidator.evWrongInputList,
3542 TosaErrorValidator.evWrongOutputList,
3543 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003544 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003545 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003546 "data_gen": {
3547 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3548 },
3549 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003550 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "arithmetic_right_shift": {
3552 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3553 "operands": (2, 0),
3554 "build_fcn": (
3555 build_arithmetic_right_shift,
3556 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 TosaArgGen.agArithmeticRightShift,
3559 ),
3560 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evRankMismatch,
3563 TosaErrorValidator.evWrongInputType,
3564 TosaErrorValidator.evWrongOutputType,
3565 TosaErrorValidator.evWrongInputList,
3566 TosaErrorValidator.evWrongOutputList,
3567 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003568 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "bitwise_and": {
3572 "op": Op.BITWISE_AND,
3573 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003574 "build_fcn": (
3575 build_binary_broadcast,
3576 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003577 TosaTensorValuesGen.tvgLazyGenDefault,
3578 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003581 "error_if_validators": (
3582 TosaErrorValidator.evRankMismatch,
3583 TosaErrorValidator.evWrongInputType,
3584 TosaErrorValidator.evWrongOutputType,
3585 TosaErrorValidator.evWrongInputList,
3586 TosaErrorValidator.evWrongOutputList,
3587 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003588 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003589 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "bitwise_or": {
3592 "op": Op.BITWISE_OR,
3593 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003594 "build_fcn": (
3595 build_binary_broadcast,
3596 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003597 TosaTensorValuesGen.tvgLazyGenDefault,
3598 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003600 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 "error_if_validators": (
3602 TosaErrorValidator.evRankMismatch,
3603 TosaErrorValidator.evWrongInputType,
3604 TosaErrorValidator.evWrongOutputType,
3605 TosaErrorValidator.evWrongInputList,
3606 TosaErrorValidator.evWrongOutputList,
3607 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003608 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003609 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003611 "bitwise_xor": {
3612 "op": Op.BITWISE_XOR,
3613 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 "build_fcn": (
3615 build_binary_broadcast,
3616 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003617 TosaTensorValuesGen.tvgLazyGenDefault,
3618 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003620 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003621 "error_if_validators": (
3622 TosaErrorValidator.evRankMismatch,
3623 TosaErrorValidator.evWrongInputType,
3624 TosaErrorValidator.evWrongOutputType,
3625 TosaErrorValidator.evWrongInputList,
3626 TosaErrorValidator.evWrongOutputList,
3627 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003628 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003631 "intdiv": {
3632 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003633 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 "build_fcn": (
3635 build_binary_broadcast,
3636 TosaTensorGen.tgBroadcastFuzz,
3637 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003638 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003640 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 "error_if_validators": (
3642 TosaErrorValidator.evRankMismatch,
3643 TosaErrorValidator.evWrongInputType,
3644 TosaErrorValidator.evWrongOutputType,
3645 TosaErrorValidator.evWrongInputList,
3646 TosaErrorValidator.evWrongOutputList,
3647 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003648 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003649 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003650 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 "logical_and": {
3652 "op": Op.LOGICAL_AND,
3653 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 "build_fcn": (
3655 build_binary_broadcast,
3656 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003657 TosaTensorValuesGen.tvgLazyGenDefault,
3658 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003661 "error_if_validators": (
3662 TosaErrorValidator.evRankMismatch,
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003668 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003669 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003671 "logical_left_shift": {
3672 "op": Op.LOGICAL_LEFT_SHIFT,
3673 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003674 "build_fcn": (
3675 build_binary_broadcast,
3676 TosaTensorGen.tgBroadcastFuzz,
3677 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003678 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003680 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003681 "error_if_validators": (
3682 TosaErrorValidator.evRankMismatch,
3683 TosaErrorValidator.evWrongInputType,
3684 TosaErrorValidator.evWrongOutputType,
3685 TosaErrorValidator.evWrongInputList,
3686 TosaErrorValidator.evWrongOutputList,
3687 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003688 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003689 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003691 "logical_right_shift": {
3692 "op": Op.LOGICAL_RIGHT_SHIFT,
3693 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 "build_fcn": (
3695 build_binary_broadcast,
3696 TosaTensorGen.tgBroadcastFuzz,
3697 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003698 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003701 "error_if_validators": (
3702 TosaErrorValidator.evRankMismatch,
3703 TosaErrorValidator.evWrongInputType,
3704 TosaErrorValidator.evWrongOutputType,
3705 TosaErrorValidator.evWrongInputList,
3706 TosaErrorValidator.evWrongOutputList,
3707 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003708 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "logical_or": {
3712 "op": Op.LOGICAL_OR,
3713 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_binary_broadcast,
3716 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003717 TosaTensorValuesGen.tvgLazyGenDefault,
3718 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evRankMismatch,
3723 TosaErrorValidator.evWrongInputType,
3724 TosaErrorValidator.evWrongOutputType,
3725 TosaErrorValidator.evWrongInputList,
3726 TosaErrorValidator.evWrongOutputList,
3727 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003728 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 "logical_xor": {
3732 "op": Op.LOGICAL_XOR,
3733 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 "build_fcn": (
3735 build_binary_broadcast,
3736 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003737 TosaTensorValuesGen.tvgLazyGenDefault,
3738 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003741 "error_if_validators": (
3742 TosaErrorValidator.evRankMismatch,
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
3747 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003748 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 "maximum": {
3752 "op": Op.MAXIMUM,
3753 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 "build_fcn": (
3755 build_binary_broadcast,
3756 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003757 TosaTensorValuesGen.tvgLazyGenDefault,
3758 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003759 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003761 "error_if_validators": (
3762 TosaErrorValidator.evRankMismatch,
3763 TosaErrorValidator.evWrongInputType,
3764 TosaErrorValidator.evWrongOutputType,
3765 TosaErrorValidator.evWrongInputList,
3766 TosaErrorValidator.evWrongOutputList,
3767 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003768 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003770 "data_gen": {
3771 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3772 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 "minimum": {
3775 "op": Op.MINIMUM,
3776 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 "build_fcn": (
3778 build_binary_broadcast,
3779 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003780 TosaTensorValuesGen.tvgLazyGenDefault,
3781 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003782 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003783 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003784 "error_if_validators": (
3785 TosaErrorValidator.evRankMismatch,
3786 TosaErrorValidator.evWrongInputType,
3787 TosaErrorValidator.evWrongOutputType,
3788 TosaErrorValidator.evWrongInputList,
3789 TosaErrorValidator.evWrongOutputList,
3790 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003791 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003793 "data_gen": {
3794 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3795 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "mul": {
3798 "op": Op.MUL,
3799 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 "build_fcn": (
3801 build_mul,
3802 TosaTensorGen.tgBroadcastFuzz,
3803 TosaTensorValuesGen.tvgMul,
3804 TosaArgGen.agMul,
3805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 "error_if_validators": (
3808 TosaErrorValidator.evWrongInputType,
3809 TosaErrorValidator.evWrongOutputType,
3810 TosaErrorValidator.evWrongInputList,
3811 TosaErrorValidator.evWrongOutputList,
3812 TosaErrorValidator.evRankMismatch,
3813 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003814 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003815 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003816 "data_gen": {
3817 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3818 },
3819 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 "pow": {
3822 "op": Op.POW,
3823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 "build_fcn": (
3825 build_binary_broadcast,
3826 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003827 TosaTensorValuesGen.tvgPow,
3828 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003838 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003840 "data_gen": {
3841 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3842 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "sub": {
3845 "op": Op.SUB,
3846 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 "build_fcn": (
3848 build_binary_broadcast,
3849 TosaTensorGen.tgBroadcastFuzz,
3850 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003851 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evRankMismatch,
3856 TosaErrorValidator.evWrongInputType,
3857 TosaErrorValidator.evWrongOutputType,
3858 TosaErrorValidator.evWrongInputList,
3859 TosaErrorValidator.evWrongOutputList,
3860 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003861 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003863 "data_gen": {
3864 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3865 },
3866 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 "table": {
3869 "op": Op.TABLE,
3870 # Use the automatic generation functions to create the input array
3871 # but create the table tensor in the build function, as it may be
3872 # a different type from the input
3873 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003874 "build_fcn": (
3875 build_table,
3876 TosaTensorGen.tgBasic,
3877 TosaTensorValuesGen.tvgDefault,
3878 TosaArgGen.agTable,
3879 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003880 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 "error_if_validators": (
3882 TosaErrorValidator.evWrongInputType,
3883 TosaErrorValidator.evWrongOutputType,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003887 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 # Elementwise Unary operators
3889 "abs": {
3890 "op": Op.ABS,
3891 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 "build_fcn": (
3893 build_unary,
3894 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003895 TosaTensorValuesGen.tvgLazyGenDefault,
3896 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 "error_if_validators": (
3900 TosaErrorValidator.evWrongInputType,
3901 TosaErrorValidator.evWrongOutputType,
3902 TosaErrorValidator.evWrongInputList,
3903 TosaErrorValidator.evWrongOutputList,
3904 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003905 "data_gen": {
3906 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3907 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 "bitwise_not": {
3910 "op": Op.BITWISE_NOT,
3911 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003912 "build_fcn": (
3913 build_unary,
3914 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003915 TosaTensorValuesGen.tvgLazyGenDefault,
3916 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003918 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 "error_if_validators": (
3920 TosaErrorValidator.evWrongInputType,
3921 TosaErrorValidator.evWrongOutputType,
3922 TosaErrorValidator.evWrongInputList,
3923 TosaErrorValidator.evWrongOutputList,
3924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "ceil": {
3927 "op": Op.CEIL,
3928 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 "build_fcn": (
3930 build_unary,
3931 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003932 TosaTensorValuesGen.tvgLazyGenDefault,
3933 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003936 "error_if_validators": (
3937 TosaErrorValidator.evWrongInputType,
3938 TosaErrorValidator.evWrongOutputType,
3939 TosaErrorValidator.evWrongInputList,
3940 TosaErrorValidator.evWrongOutputList,
3941 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003942 "data_gen": {
3943 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3944 },
3945 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "clz": {
3948 "op": Op.CLZ,
3949 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 "build_fcn": (
3951 build_unary,
3952 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003953 TosaTensorValuesGen.tvgLazyGenDefault,
3954 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 "error_if_validators": (
3958 TosaErrorValidator.evWrongInputType,
3959 TosaErrorValidator.evWrongOutputType,
3960 TosaErrorValidator.evWrongInputList,
3961 TosaErrorValidator.evWrongOutputList,
3962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 "exp": {
3965 "op": Op.EXP,
3966 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 "build_fcn": (
3968 build_unary,
3969 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003970 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003971 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003973 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003974 "error_if_validators": (
3975 TosaErrorValidator.evWrongInputType,
3976 TosaErrorValidator.evWrongOutputType,
3977 TosaErrorValidator.evWrongInputList,
3978 TosaErrorValidator.evWrongOutputList,
3979 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003980 "data_gen": {
3981 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3982 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "floor": {
3985 "op": Op.FLOOR,
3986 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003987 "build_fcn": (
3988 build_unary,
3989 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003990 TosaTensorValuesGen.tvgLazyGenDefault,
3991 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003994 "error_if_validators": (
3995 TosaErrorValidator.evWrongInputType,
3996 TosaErrorValidator.evWrongOutputType,
3997 TosaErrorValidator.evWrongInputList,
3998 TosaErrorValidator.evWrongOutputList,
3999 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004000 "data_gen": {
4001 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4002 },
4003 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 "log": {
4006 "op": Op.LOG,
4007 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004008 "build_fcn": (
4009 build_unary,
4010 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004011 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004012 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004013 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004015 "error_if_validators": (
4016 TosaErrorValidator.evWrongInputType,
4017 TosaErrorValidator.evWrongOutputType,
4018 TosaErrorValidator.evWrongInputList,
4019 TosaErrorValidator.evWrongOutputList,
4020 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004021 "data_gen": {
4022 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4023 },
4024 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004026 "logical_not": {
4027 "op": Op.LOGICAL_NOT,
4028 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004029 "build_fcn": (
4030 build_unary,
4031 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004032 TosaTensorValuesGen.tvgLazyGenDefault,
4033 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004036 "error_if_validators": (
4037 TosaErrorValidator.evWrongInputType,
4038 TosaErrorValidator.evWrongOutputType,
4039 TosaErrorValidator.evWrongInputList,
4040 TosaErrorValidator.evWrongOutputList,
4041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004042 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004043 "negate": {
4044 "op": Op.NEGATE,
4045 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004046 "build_fcn": (
4047 build_unary,
4048 TosaTensorGen.tgBasic,
4049 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004050 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004051 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004052 "qgen": TosaQuantGen.qgUnary,
4053 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004054 "error_if_validators": (
4055 TosaErrorValidator.evInputZeroPointNotZero,
4056 TosaErrorValidator.evOutputZeroPointNotZero,
4057 TosaErrorValidator.evWrongInputType,
4058 TosaErrorValidator.evWrongOutputType,
4059 TosaErrorValidator.evWrongInputList,
4060 TosaErrorValidator.evWrongOutputList,
4061 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004062 "data_gen": {
4063 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4064 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 "reciprocal": {
4067 "op": Op.RECIPROCAL,
4068 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 "build_fcn": (
4070 build_unary,
4071 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004072 TosaTensorValuesGen.tvgLazyGenDefault,
4073 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004074 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004075 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004076 "error_if_validators": (
4077 TosaErrorValidator.evWrongInputType,
4078 TosaErrorValidator.evWrongOutputType,
4079 TosaErrorValidator.evWrongInputList,
4080 TosaErrorValidator.evWrongOutputList,
4081 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004082 "data_gen": {
4083 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4084 },
4085 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004087 "rsqrt": {
4088 "op": Op.RSQRT,
4089 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 "build_fcn": (
4091 build_unary,
4092 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004093 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004094 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004097 "error_if_validators": (
4098 TosaErrorValidator.evWrongInputType,
4099 TosaErrorValidator.evWrongOutputType,
4100 TosaErrorValidator.evWrongInputList,
4101 TosaErrorValidator.evWrongOutputList,
4102 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004103 "data_gen": {
4104 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4105 },
4106 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 # Elementwise Ternary operators
4109 "select": {
4110 "op": Op.SELECT,
4111 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004112 "build_fcn": (
4113 build_select,
4114 TosaTensorGen.tgBroadcastFuzz,
4115 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004116 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004117 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004118 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004119 "error_if_validators": (
4120 TosaErrorValidator.evRankMismatch,
4121 TosaErrorValidator.evWrongInputType,
4122 TosaErrorValidator.evWrongOutputType,
4123 TosaErrorValidator.evWrongInputList,
4124 TosaErrorValidator.evWrongOutputList,
4125 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004126 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004127 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004128 "data_gen": {
4129 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4130 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004132 # Comparison operators
4133 "equal": {
4134 "op": Op.EQUAL,
4135 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004136 "build_fcn": (
4137 build_comparison,
4138 TosaTensorGen.tgBroadcastFuzz,
4139 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004140 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004142 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004143 "error_if_validators": (
4144 TosaErrorValidator.evRankMismatch,
4145 TosaErrorValidator.evWrongInputType,
4146 TosaErrorValidator.evWrongOutputType,
4147 TosaErrorValidator.evWrongInputList,
4148 TosaErrorValidator.evWrongOutputList,
4149 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004150 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004152 "data_gen": {
4153 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4154 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004155 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004156 "greater_equal": {
4157 "op": Op.GREATER_EQUAL,
4158 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004159 "build_fcn": (
4160 build_comparison,
4161 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004162 TosaTensorValuesGen.tvgLazyGenDefault,
4163 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004164 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004166 "error_if_validators": (
4167 TosaErrorValidator.evRankMismatch,
4168 TosaErrorValidator.evWrongInputType,
4169 TosaErrorValidator.evWrongOutputType,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004173 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004174 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004175 "data_gen": {
4176 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004179 "greater": {
4180 "op": Op.GREATER,
4181 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004182 "build_fcn": (
4183 build_comparison,
4184 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004185 TosaTensorValuesGen.tvgLazyGenDefault,
4186 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004187 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004188 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004189 "error_if_validators": (
4190 TosaErrorValidator.evRankMismatch,
4191 TosaErrorValidator.evWrongInputType,
4192 TosaErrorValidator.evWrongOutputType,
4193 TosaErrorValidator.evWrongInputList,
4194 TosaErrorValidator.evWrongOutputList,
4195 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004196 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004197 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004198 "data_gen": {
4199 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004202 # Reduction operators
4203 "reduce_all": {
4204 "op": Op.REDUCE_ALL,
4205 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004206 "build_fcn": (
4207 build_reduce,
4208 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004209 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004210 TosaArgGen.agAxis,
4211 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004212 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004213 "error_if_validators": (
4214 TosaErrorValidator.evAxisLargerRank,
4215 TosaErrorValidator.evAxisSmallerZero,
4216 TosaErrorValidator.evShapeOfAxisNotOne,
4217 TosaErrorValidator.evWrongInputType,
4218 TosaErrorValidator.evWrongOutputType,
4219 TosaErrorValidator.evWrongRank,
4220 TosaErrorValidator.evWrongInputList,
4221 TosaErrorValidator.evWrongOutputList,
4222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004224 "reduce_any": {
4225 "op": Op.REDUCE_ANY,
4226 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004227 "build_fcn": (
4228 build_reduce,
4229 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004230 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004231 TosaArgGen.agAxis,
4232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004234 "error_if_validators": (
4235 TosaErrorValidator.evAxisLargerRank,
4236 TosaErrorValidator.evAxisSmallerZero,
4237 TosaErrorValidator.evShapeOfAxisNotOne,
4238 TosaErrorValidator.evWrongInputType,
4239 TosaErrorValidator.evWrongOutputType,
4240 TosaErrorValidator.evWrongRank,
4241 TosaErrorValidator.evWrongInputList,
4242 TosaErrorValidator.evWrongOutputList,
4243 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004244 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004245 "reduce_max": {
4246 "op": Op.REDUCE_MAX,
4247 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004248 "build_fcn": (
4249 build_reduce,
4250 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004251 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004252 TosaArgGen.agAxis,
4253 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004255 "error_if_validators": (
4256 TosaErrorValidator.evAxisLargerRank,
4257 TosaErrorValidator.evAxisSmallerZero,
4258 TosaErrorValidator.evShapeOfAxisNotOne,
4259 TosaErrorValidator.evWrongInputType,
4260 TosaErrorValidator.evWrongOutputType,
4261 TosaErrorValidator.evWrongRank,
4262 TosaErrorValidator.evWrongInputList,
4263 TosaErrorValidator.evWrongOutputList,
4264 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004265 "data_gen": {
4266 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4267 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004269 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004270 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004271 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004272 "build_fcn": (
4273 build_reduce,
4274 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004275 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004276 TosaArgGen.agAxis,
4277 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 "error_if_validators": (
4280 TosaErrorValidator.evAxisLargerRank,
4281 TosaErrorValidator.evAxisSmallerZero,
4282 TosaErrorValidator.evShapeOfAxisNotOne,
4283 TosaErrorValidator.evWrongInputType,
4284 TosaErrorValidator.evWrongOutputType,
4285 TosaErrorValidator.evWrongRank,
4286 TosaErrorValidator.evWrongInputList,
4287 TosaErrorValidator.evWrongOutputList,
4288 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004289 "data_gen": {
4290 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004292 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004293 "reduce_product": {
4294 "op": Op.REDUCE_PRODUCT,
4295 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004296 "build_fcn": (
4297 build_reduce,
4298 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004299 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004300 TosaArgGen.agAxis,
4301 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004302 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004303 "error_if_validators": (
4304 TosaErrorValidator.evAxisLargerRank,
4305 TosaErrorValidator.evAxisSmallerZero,
4306 TosaErrorValidator.evShapeOfAxisNotOne,
4307 TosaErrorValidator.evWrongInputType,
4308 TosaErrorValidator.evWrongOutputType,
4309 TosaErrorValidator.evWrongRank,
4310 TosaErrorValidator.evWrongInputList,
4311 TosaErrorValidator.evWrongOutputList,
4312 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004313 "data_gen": {
4314 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4315 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004317 "reduce_sum": {
4318 "op": Op.REDUCE_SUM,
4319 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004320 "build_fcn": (
4321 build_reduce,
4322 TosaTensorGen.tgBasic,
4323 TosaTensorValuesGen.tvgReduceSum,
4324 TosaArgGen.agAxis,
4325 ),
James Ward24dbc422022-10-19 12:20:31 +01004326 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004327 "error_if_validators": (
4328 TosaErrorValidator.evAxisLargerRank,
4329 TosaErrorValidator.evAxisSmallerZero,
4330 TosaErrorValidator.evShapeOfAxisNotOne,
4331 TosaErrorValidator.evWrongInputType,
4332 TosaErrorValidator.evWrongOutputType,
4333 TosaErrorValidator.evWrongRank,
4334 TosaErrorValidator.evWrongInputList,
4335 TosaErrorValidator.evWrongOutputList,
4336 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004337 "data_gen": {
4338 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4339 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004340 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004341 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 "concat": {
4343 "op": Op.CONCAT,
4344 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004345 "build_fcn": (
4346 build_concat,
4347 TosaTensorGen.tgConcat,
4348 TosaTensorValuesGen.tvgConcat,
4349 TosaArgGen.agAxis,
4350 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004351 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004352 "error_if_validators": (
4353 TosaErrorValidator.evAxisLargerRank,
4354 TosaErrorValidator.evAxisSmallerZero,
4355 TosaErrorValidator.evConcatInputRankMismatch,
4356 TosaErrorValidator.evConcatShapeSumMismatch,
4357 TosaErrorValidator.evConcatInputDimMismatch,
4358 TosaErrorValidator.evWrongInputType,
4359 TosaErrorValidator.evWrongOutputType,
4360 TosaErrorValidator.evWrongOutputList,
4361 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004362 "data_gen": {
4363 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4364 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004365 },
4366 "pad": {
4367 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004368 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004369 "build_fcn": (
4370 build_pad,
4371 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004372 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004373 TosaArgGen.agPad,
4374 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004376 "error_if_validators": (
4377 TosaErrorValidator.evWrongInputType,
4378 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004379 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004380 TosaErrorValidator.evWrongOutputType,
4381 TosaErrorValidator.evWrongInputList,
4382 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004383 TosaErrorValidator.evRankMismatch,
4384 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004385 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004386 "data_gen": {
4387 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4388 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004389 },
Won Jeona21b2e82023-08-10 10:33:01 +00004390 "dim": {
4391 "op": Op.DIM,
4392 "operands": (1, 0),
4393 "build_fcn": (
4394 build_dim,
4395 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004396 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004397 TosaArgGen.agAxis,
4398 ),
4399 "types": TYPE_FIB,
4400 "error_if_validators": (
4401 TosaErrorValidator.evAxisLargerRank,
4402 TosaErrorValidator.evAxisSmallerZero,
4403 TosaErrorValidator.evWrongInputType,
4404 TosaErrorValidator.evWrongInputList,
4405 TosaErrorValidator.evWrongOutputList,
4406 TosaErrorValidator.evWrongRank,
4407 ),
4408 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004409 "reshape": {
4410 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004411 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004412 "build_fcn": (
4413 build_reshape,
4414 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004415 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004416 TosaArgGen.agReshape,
4417 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004418 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004419 "error_if_validators": (
4420 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4421 TosaErrorValidator.evWrongInputType,
4422 TosaErrorValidator.evWrongOutputType,
4423 TosaErrorValidator.evWrongInputList,
4424 TosaErrorValidator.evWrongOutputList,
4425 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004426 "data_gen": {
4427 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4428 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004429 },
4430 "reverse": {
4431 "op": Op.REVERSE,
4432 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004433 "build_fcn": (
4434 build_reverse,
4435 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004436 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 TosaArgGen.agAxis,
4438 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004439 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 "error_if_validators": (
4441 TosaErrorValidator.evAxisSmallerZero,
4442 TosaErrorValidator.evAxisLargerRank,
4443 TosaErrorValidator.evWrongInputType,
4444 TosaErrorValidator.evWrongOutputType,
4445 TosaErrorValidator.evWrongInputList,
4446 TosaErrorValidator.evWrongOutputList,
4447 ),
evacha0198477222024-01-26 12:25:32 +00004448 "data_gen": {
4449 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4450 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004451 },
4452 "slice": {
4453 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004454 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004455 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004456 "build_fcn": (
4457 build_slice,
4458 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004459 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004460 TosaArgGen.agSlice,
4461 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004462 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004463 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004464 # TODO Turn off these error categories for now as the reference
4465 # model cannot allocate memory space for empty tensor. We probably
4466 # can report an accurate error message at the right place during
4467 # execution.
4468 # TosaErrorValidator.evStartSmallerZero,
4469 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 TosaErrorValidator.evStartSizeOutsideBounds,
4471 TosaErrorValidator.evSizeOutputShapeMismatch,
4472 TosaErrorValidator.evInputSizeStartLengthMismatch,
4473 TosaErrorValidator.evWrongRank,
4474 TosaErrorValidator.evWrongInputType,
4475 TosaErrorValidator.evWrongOutputType,
4476 TosaErrorValidator.evWrongInputList,
4477 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004478 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004479 ),
evacha017f7d4252024-01-24 12:08:09 +00004480 "data_gen": {
4481 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4482 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004483 },
4484 "tile": {
4485 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004486 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004487 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004488 "build_fcn": (
4489 build_tile,
4490 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004491 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004492 TosaArgGen.agTile,
4493 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004494 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004495 "error_if_validators": (
4496 TosaErrorValidator.evWrongInputType,
4497 TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongInputList,
4499 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004500 TosaErrorValidator.evRankMismatch,
4501 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004502 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004503 "data_gen": {
4504 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4505 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004506 },
4507 "transpose": {
4508 "op": Op.TRANSPOSE,
4509 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004510 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004511 "build_fcn": (
4512 build_transpose,
4513 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004514 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004515 TosaArgGen.agTranspose,
4516 ),
4517 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 "error_if_validators": (
4519 TosaErrorValidator.evIndexOutsideBounds,
4520 TosaErrorValidator.evIndexUsedTwice,
4521 TosaErrorValidator.evWrongInputType,
4522 TosaErrorValidator.evWrongOutputType,
4523 TosaErrorValidator.evWrongInputList,
4524 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004525 TosaErrorValidator.evWrongRank,
4526 TosaErrorValidator.evRankMismatch,
4527 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004528 ),
evacha0198477222024-01-26 12:25:32 +00004529 "data_gen": {
4530 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4531 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004532 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004533 # Data nodes
4534 "const": {
4535 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004536 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004537 "build_fcn": (
4538 build_const,
4539 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004540 TosaTensorValuesGen.tvgLazyGenDefault,
4541 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004542 ),
Luke Hutton65872422023-02-20 10:33:04 +00004543 "types": TYPE_FIB + [DType.INT48],
evacha0198477222024-01-26 12:25:32 +00004544 "data_gen": {
4545 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004548 "identity": {
4549 "op": Op.IDENTITY,
4550 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004551 "build_fcn": (
4552 build_unary,
4553 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004554 TosaTensorValuesGen.tvgLazyGenDefault,
4555 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004556 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004557 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004558 "data_gen": {
4559 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4560 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004561 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004562 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004563 "gather": {
4564 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004565 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004566 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004567 "build_fcn": (
4568 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004569 TosaTensorGen.tgGather,
4570 TosaTensorValuesGen.tvgGather,
4571 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004572 ),
James Ward24dbc422022-10-19 12:20:31 +01004573 "types": (
4574 DType.INT8,
4575 DType.INT16,
4576 DType.INT32,
4577 DType.FP16,
4578 DType.BF16,
4579 DType.FP32,
4580 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004581 "error_if_validators": (
4582 TosaErrorValidator.evWrongInputType,
4583 TosaErrorValidator.evWrongOutputType,
4584 TosaErrorValidator.evWrongInputList,
4585 TosaErrorValidator.evWrongOutputList,
4586 TosaErrorValidator.evWrongRank,
4587 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004588 "data_gen": {
4589 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4590 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004591 },
4592 "scatter": {
4593 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004594 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004595 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004596 "build_fcn": (
4597 build_scatter,
4598 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004599 TosaTensorValuesGen.tvgScatter,
4600 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004601 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004602 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004603 "error_if_validators": (
4604 TosaErrorValidator.evWrongInputType,
4605 TosaErrorValidator.evWrongOutputType,
4606 TosaErrorValidator.evWrongInputList,
4607 TosaErrorValidator.evWrongOutputList,
4608 TosaErrorValidator.evWrongRank,
4609 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004610 "data_gen": {
4611 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4612 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004613 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004614 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004615 "resize": {
4616 "op": Op.RESIZE,
4617 "operands": (1, 0),
4618 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004619 "build_fcn": (
4620 build_resize,
4621 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004622 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004623 TosaArgGen.agResize,
4624 ),
James Ward24dbc422022-10-19 12:20:31 +01004625 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004626 "invalid_test_validators": (
4627 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004628 ),
4629 "error_if_validators": (
4630 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004631 TosaErrorValidator.evScaleSmallerEqualZero,
4632 TosaErrorValidator.evScaleNLargerMax,
4633 TosaErrorValidator.evScaleDLargerMax,
4634 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004635 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004636 TosaErrorValidator.evBorderSmallerMin,
4637 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 TosaErrorValidator.evWrongInputType,
4639 TosaErrorValidator.evWrongOutputType,
4640 TosaErrorValidator.evWrongRank,
4641 TosaErrorValidator.evWrongInputList,
4642 TosaErrorValidator.evWrongOutputList,
4643 TosaErrorValidator.evBatchMismatch,
4644 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004645 TosaErrorValidator.evResizeOutputShapeMismatch,
4646 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004647 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004648 "data_gen": {
4649 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4650 },
4651 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004652 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004653 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004654 "cast": {
4655 "op": Op.CAST,
4656 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004657 "build_fcn": (
4658 build_cast,
4659 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004660 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004661 TosaArgGen.agCast,
4662 ),
James Ward8b390432022-08-12 20:48:56 +01004663 "types": (
4664 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004665 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004666 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004667 DType.INT8,
4668 DType.INT16,
4669 DType.INT32,
4670 DType.BOOL,
4671 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004672 "error_if_validators": (
4673 TosaErrorValidator.evWrongInputType,
4674 TosaErrorValidator.evWrongOutputType,
4675 TosaErrorValidator.evWrongInputList,
4676 TosaErrorValidator.evWrongOutputList,
4677 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004678 "data_gen": {
4679 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4680 },
4681 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004682 },
4683 "rescale": {
4684 "op": Op.RESCALE,
4685 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004686 "build_fcn": (
4687 build_rescale,
4688 TosaTensorGen.tgBasic,
4689 TosaTensorValuesGen.tvgDefault,
4690 TosaArgGen.agRescale,
4691 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004692 "types": [
4693 DType.UINT8,
4694 DType.INT8,
4695 DType.INT16,
4696 DType.INT32,
4697 DType.INT48,
4698 DType.UINT16,
4699 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004700 "error_if_validators": (
4701 TosaErrorValidator.evInputZeroPointNotZero,
4702 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004703 TosaErrorValidator.evU16InputZeroPointNotValid,
4704 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004705 TosaErrorValidator.evScaleTrue,
4706 TosaErrorValidator.evScaleNotTrue,
4707 TosaErrorValidator.evWrongInputType,
4708 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004709 TosaErrorValidator.evWrongInputList,
4710 TosaErrorValidator.evWrongOutputList,
4711 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004712 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004713 # Custom
4714 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004715 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004716 # Two variants of cond_if, one that generates one of two constant tensors (no
4717 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4718 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004719 "cond_if_const": {
4720 "op": Op.COND_IF,
4721 "operands": (0, 2),
4722 "build_fcn": (
4723 build_cond_if_const,
4724 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004725 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004726 TosaArgGen.agCondIf,
4727 ),
4728 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004729 "error_if_validators": (
4730 TosaErrorValidator.evOutputListThenGraphMismatch,
4731 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004732 TosaErrorValidator.evCondIfCondNotMatchingBool,
4733 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004734 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004735 },
4736 "cond_if_binary": {
4737 "op": Op.COND_IF,
4738 "operands": (2, 0),
4739 "build_fcn": (
4740 build_cond_if_binary,
4741 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004742 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004743 TosaArgGen.agCondIf,
4744 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004745 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004746 "error_if_validators": (
4747 TosaErrorValidator.evInputListThenGraphMismatch,
4748 TosaErrorValidator.evInputListElseGraphMismatch,
4749 TosaErrorValidator.evOutputListThenGraphMismatch,
4750 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004751 TosaErrorValidator.evCondIfCondNotMatchingBool,
4752 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004753 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004754 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004755 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004756 "while_loop": {
4757 "op": Op.WHILE_LOOP,
4758 "operands": (0, 1),
4759 "build_fcn": (
4760 build_while_loop,
4761 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004762 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004763 TosaArgGen.agWhileLoop,
4764 ),
4765 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004766 "error_if_validators": (
4767 TosaErrorValidator.evInputListOutputListMismatch,
4768 TosaErrorValidator.evInputListCondGraphMismatch,
4769 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4770 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4771 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004772 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004773 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004774 },
Luke Hutton57287132023-02-06 14:54:18 +00004775 "fft2d": {
4776 "op": Op.FFT2D,
4777 "operands": (2, 0),
4778 "rank": (3, 3),
4779 "build_fcn": (
4780 build_fft2d,
4781 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004782 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004783 TosaArgGen.agFFT2d,
4784 ),
4785 "types": [DType.FP32],
4786 "error_if_validators": (
4787 TosaErrorValidator.evWrongInputType,
4788 TosaErrorValidator.evWrongOutputType,
4789 TosaErrorValidator.evWrongInputList,
4790 TosaErrorValidator.evWrongOutputList,
4791 TosaErrorValidator.evWrongRank,
4792 TosaErrorValidator.evBatchMismatch,
4793 TosaErrorValidator.evKernelNotPowerOfTwo,
4794 TosaErrorValidator.evFFTInputShapeMismatch,
4795 TosaErrorValidator.evFFTOutputShapeMismatch,
4796 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004797 "data_gen": {
4798 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4799 },
Luke Hutton57287132023-02-06 14:54:18 +00004800 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004801 "rfft2d": {
4802 "op": Op.RFFT2D,
4803 "operands": (1, 0),
4804 "rank": (3, 3),
4805 "build_fcn": (
4806 build_rfft2d,
4807 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004808 TosaTensorValuesGen.tvgLazyGenDefault,
4809 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004810 ),
4811 "types": [DType.FP32],
4812 "error_if_validators": (
4813 TosaErrorValidator.evWrongInputType,
4814 TosaErrorValidator.evWrongOutputType,
4815 TosaErrorValidator.evWrongInputList,
4816 TosaErrorValidator.evWrongOutputList,
4817 TosaErrorValidator.evWrongRank,
4818 TosaErrorValidator.evBatchMismatch,
4819 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004820 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004821 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004822 "data_gen": {
4823 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4824 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004825 },
Won Jeon74342e52024-01-09 00:34:40 +00004826 # Shape
4827 "add_shape": {
4828 "op": Op.ADD_SHAPE,
4829 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004830 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004831 "build_fcn": (
4832 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004833 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004834 TosaTensorValuesGen.tvgAddSub,
4835 TosaArgGen.agNone,
4836 ),
4837 "types": [DType.SHAPE],
4838 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4839 },
4840 "sub_shape": {
4841 "op": Op.SUB_SHAPE,
4842 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004843 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004844 "build_fcn": (
4845 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004846 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004847 TosaTensorValuesGen.tvgAddSub,
4848 TosaArgGen.agNone,
4849 ),
4850 "types": [DType.SHAPE],
4851 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4852 },
4853 "mul_shape": {
4854 "op": Op.MUL_SHAPE,
4855 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004856 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004857 "build_fcn": (
4858 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004859 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004860 TosaTensorValuesGen.tvgMul,
4861 TosaArgGen.agNone,
4862 ),
4863 "types": [DType.SHAPE],
4864 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4865 },
4866 "div_shape": {
4867 "op": Op.DIV_SHAPE,
4868 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004869 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004870 "build_fcn": (
4871 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004872 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004873 TosaTensorValuesGen.tvgIntDiv,
4874 TosaArgGen.agNone,
4875 ),
4876 "types": [DType.SHAPE],
4877 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4878 },
4879 "concat_shape": {
4880 "op": Op.CONCAT_SHAPE,
4881 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004882 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004883 "build_fcn": (
4884 build_concat,
4885 TosaTensorGen.tgConcat,
4886 TosaTensorValuesGen.tvgConcat,
4887 TosaArgGen.agNone,
4888 ),
4889 "types": [DType.SHAPE],
4890 "error_if_validators": (),
4891 },
4892 "const_shape": {
4893 "op": Op.CONST_SHAPE,
4894 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004895 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004896 "build_fcn": (
4897 build_const,
4898 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004899 TosaTensorValuesGen.tvgLazyGenDefault,
4900 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004901 ),
4902 "types": [DType.SHAPE],
4903 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004904 }
4905
Kevin Cheng550ccc52021-03-03 11:21:43 -08004906
Eric Kunzee5e26762020-10-13 16:11:07 -07004907class OutputShaper:
4908 # Methods in this class compute the expected output shape and datatype
4909 # for common classes of operations
4910 def __init__(self):
4911 pass
4912
4913 # These methods return arguments that can be used for
4914 # creating a new output tensor
4915 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting elementwise binary operator.

    When error_name is None, every size-1 dimension of `a` takes the
    matching dimension of `b`; for any ERROR_IF test the shape of `a` is
    used as-is.  ErrorIf.DimensionMismatch then adds one to a randomly
    chosen dimension, and ErrorIf.WrongOutputType swaps the dtype for a
    random one that differs from `a.dtype`.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # Ranks may only differ when a rank mismatch is the error under test
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        b.shape[axis] if (dim == 1 and broadcasting) else dim
        for axis, dim in enumerate(a.shape)
    ]

    # The index is drawn on every call, whether or not it ends up being used
    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004948
4949 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary operator without broadcasting.

    Both inputs must agree in rank, dtype and every dimension (checked
    with assertions).  The output gets a fresh copy of that shape and the
    dtype of `a`.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for axis, dim in enumerate(a.shape):
        assert dim == b.shape[axis]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004960
4961 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary operator.

    The output has the shape of `a`.  Its dtype is `a.dtype`, except for
    ErrorIf.WrongOutputType where a random dtype different from `a.dtype`
    is chosen.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    bad_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
4980 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select operator.

    The output shape follows `cond`; when error_name is None each size-1
    dimension of `cond` is replaced by the largest of the matching
    dimensions of `cond`, `a` and `b`.  ErrorIf.DimensionMismatch adds one
    to a randomly chosen dimension, and ErrorIf.WrongOutputType picks a
    random dtype different from `a.dtype`.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # Ranks may only differ when a rank mismatch is the error under test
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = []
    for axis, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and broadcasting:
            cond_dim = max(cond_dim, a.shape[axis], b.shape[axis])
        out_shape.append(cond_dim)

    # The index is drawn on every call, whether or not it ends up being used
    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005013
5014 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting comparison operator.

    Size-1 dimensions of `a` take the matching dimension of `b` whenever
    `b` has that dimension.  The output dtype is DType.BOOL, except for
    ErrorIf.WrongOutputType where one of the listed non-bool dtypes is
    chosen at random.  ErrorIf.DimensionMismatch adds one to a randomly
    chosen dimension.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # Ranks may only differ when a rank mismatch is the error under test
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    out_shape = [
        b.shape[axis] if (dim == 1 and len(b.shape) > axis) else dim
        for axis, dim in enumerate(a.shape)
    ]

    # The index is drawn on every call, whether or not it ends up being used
    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005047
5048 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along `axis`.

    The output is a copy of the shape of `a` with the reduced axis set to
    1, unless the test targets an axis error (AxisSmallerZero,
    AxisLargerRank, ShapeOfAxisNotOne).  For ShapeOfAxisNotOne an axis
    that is already 1 is replaced by a random size from
    rng.integers(2, 10).  ErrorIf.WrongOutputType picks a random dtype
    different from `a.dtype`.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = a.shape.copy()

    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005076
5077 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for an argmax along `axis`.

    The output shape is the shape of `a` with `axis` removed (left in
    place for AxisSmallerZero / AxisLargerRank tests).
    ArgmaxOutputRankMismatch then either drops the first dimension or
    appends a trailing 1; ArgmaxOutputShapeMismatch adds a random amount
    from rng.integers(1, 10) to every dimension.  The dtype is
    DType.INT32, or a random other dtype for ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # One random draw per dimension, in dimension order
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005110
5111 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Height and width are derived from the input size, padding, dilated
    kernel size and stride.  ErrorIf.ConvOutputShapeMismatch enlarges the
    height, the width, or both by a random multiple of the stride.  The
    dtype is `accum_dtype` (DType.INT32 for WrongInputType tests), or a
    random dtype from gtu.usableDTypes for ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    def extent(axis):
        # Spatial axis 0 is height, 1 is width
        span = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        taps = (filter.shape[axis + 1] - 1) * dilations[axis]
        return (span - taps) // strides[axis] + 1

    out_h, out_w = extent(0), extent(1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            out_h = out_h + (rng.choice(choices) * strides[0])
        if change in (2, 3):
            out_w = out_w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005162
5163 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    Depth, height and width are derived from the input size, padding,
    dilated kernel size and stride.  ErrorIf.ConvOutputShapeMismatch
    enlarges one of them, or all three, by a random multiple of the
    stride.  The dtype is `accum_dtype` (DType.INT32 for WrongInputType
    tests), or a random dtype from gtu.usableDTypes for
    ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # IFM: NDHWC
    # Filter: ODHWI
    # OFM: NDHWC

    def extent(axis):
        # Spatial axis 0 is depth, 1 is height, 2 is width
        span = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        taps = (filter.shape[axis + 1] - 1) * dilations[axis]
        return (span - taps) // strides[axis] + 1

    out_d, out_h, out_w = extent(0), extent(1), extent(2)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 4):
            out_d = out_d + (rng.choice(choices) * strides[0])
        if change in (2, 4):
            out_h = out_h + (rng.choice(choices) * strides[1])
        if change in (3, 4):
            out_w = out_w + (rng.choice(choices) * strides[2])

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5224
5225 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a depthwise 2D convolution.

    Height and width are derived from the input size, padding, dilated
    kernel size and stride; the channel count is the product of the last
    two filter dimensions.  ErrorIf.ConvOutputShapeMismatch enlarges the
    height, the width, or both by a random multiple of the stride.  The
    dtype is `accum_dtype` (DType.INT32 for WrongInputType tests), or a
    random dtype from gtu.usableDTypes for ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # IFM: NHWC
    # Filter: HWCM
    # OFM: NHW C*M

    def extent(axis):
        # Spatial axis 0 is height, 1 is width; the filter is indexed from H
        span = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        taps = (filter.shape[axis] - 1) * dilations[axis]
        return (span - taps) // strides[axis] + 1

    out_h, out_w = extent(0), extent(1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            out_h = out_h + (rng.choice(choices) * strides[0])
        if change in (2, 3):
            out_w = out_w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005275
5276 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator.

    Height and width come from the input size, padding, kernel and
    stride; with a non-positive stride or a negative pad both are set to
    1.  ErrorIf.PoolingOutputShapeMismatch enlarges the height, the width,
    or both by a random multiple of the stride.  The dtype is `ifm.dtype`,
    or a random different dtype for ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # input: NHWC
    params_usable = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_usable:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        out_h = 1
        out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            out_h = out_h + (rng.choice(choices) * stride[0])
        if change in (2, 3):
            out_w = out_w + (rng.choice(choices) * stride[1])
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005313
5314 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a fully connected operator.

    The output shape is [input.shape[0], filter.shape[0]] and the dtype is
    `accum_dtype`.  `rng` and `error_name` are accepted but not used here.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # input: N, IC
    # filter: OC, IC
    # output: N, OC
    batch = input.shape[0]
    out_channels = filter.shape[0]

    # Validated in arg_gen (also invalidated for ErrorIf)
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005326
5327 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a batched matrix multiply.

    The output shape is [a.shape[0], a.shape[1], b.shape[2]].  The dtype
    is `accum_dtype`; for WrongInputType tests it is DType.INT32, and for
    WrongOutputType tests it is drawn at random from a per-input-dtype
    tuple of incorrect types.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # NOTE(review): an input dtype outside the cases below leaves
        # incorrect_types unbound - presumably callers never pass one; confirm
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        return ser.addOutput(output_shape, rng.choice(a=incorrect_types))

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(output_shape, DType.INT32)

    # Validated in arg_gen
    return ser.addOutput(output_shape, accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005374
5375 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for a concatenation along `axis`.

    The output starts as a copy of the first input's shape; the sizes of
    the other inputs along `axis` are added unless the test targets a rank
    mismatch or an invalid axis.  ErrorIf.ConcatShapeSumMismatch adds a
    further random amount from rng.integers(5, 10).  The dtype is that of
    the first input, or a random different dtype for
    ErrorIf.WrongOutputType.

    Returns whatever ser.addOutput(shape, dtype) returns.
    """
    first = inputs[0]

    # calculate the output shape, if possible, otherwise just use the first input shape
    output_shape = first.shape.copy()
    cannot_sum = (
        # unable to concat tensors of different ranks
        ErrorIf.ConcatInputRankMismatch,
        # unable to concat tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in cannot_sum:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005410
5411 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output tensor of a pad operation.

    Each output extent is the input extent plus the two padding amounts
    stored in ``padding[i][0]`` and ``padding[i][1]`` for that dimension.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        a: input tensor exposing ``shape`` and ``dtype``.
        padding: per-dimension pairs of padding amounts.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    for dim, extent in enumerate(a.shape):
        output_shape[dim] = padding[dim][0] + padding[dim][1] + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Grow one randomly chosen dimension by 1 or 2.
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005441
5442 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Register the single-element output tensor of a dim operation.

    The output shape is always ``[1]`` and the dtype is ``DType.SHAPE``,
    unless a wrong output type is requested. ``a`` and ``axis`` are not read.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        a: input tensor (unused here).
        axis: queried axis (unused here).
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    result_shape = [1]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(result_shape, DType.SHAPE)

    # DType.SHAPE is not in this list, so every entry is a wrong type.
    # NOTE: the round-trip through set() is kept; it fixes the order that
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates))
    return ser.addOutput(result_shape, rng.choice(mismatched))
5462
5463 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output tensor of a reshape to ``shape``.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``integers`` and ``choice``.
        a: input tensor exposing ``dtype``.
        shape: requested new shape; it is copied, not modified.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Enlarge every dimension so the element count no longer matches.
        for dim, extent in enumerate(shape):
            output_shape[dim] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005487
5488 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register the output tensor of a slice operation.

    The output shape is a copy of ``size`` and the dtype follows the input,
    unless ``error_name`` asks for a deliberately wrong result. ``start`` is
    not read here.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        input: input tensor exposing ``shape`` and ``dtype``.
        start: slice start positions (unused here).
        size: slice sizes, one per dimension.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # The dtype is chosen before the shape so the rng is consumed in the
    # same order as before.
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidates) - set([input.dtype]))
        result_dtype = rng.choice(mismatched)
    else:
        result_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim, extent in enumerate(size):
            # Extents of 2 or less are only ever increased.
            if extent <= 2:
                delta = rng.choice([1, 2])
            else:
                delta = rng.choice([-2, -1, 1, 2])
            output_shape[dim] = extent + delta
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005521
5522 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor of a tile operation.

    Each output extent is the input extent multiplied by the matching entry
    of ``multiples``.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        a: input tensor exposing ``shape`` and ``dtype``.
        multiples: repetition count per dimension; must have one entry per
            input dimension (asserted).
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim, extent in enumerate(a.shape):
        output_shape[dim] = extent * multiples[dim]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005550
5551 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor of a transpose by ``perms``.

    Output dimension ``i`` takes the extent of input dimension ``perms[i]``.
    For the index-related errors the permutation is not applied and the
    input shape is used as the starting point.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``integers`` and ``choice``.
        a: input tensor exposing ``shape`` and ``dtype``.
        perms: permutation with one entry per input dimension (asserted).
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    rank = len(output_shape)

    assert len(perms) == rank

    perms_usable = error_name not in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if perms_usable:
        for out_axis in range(rank):
            output_shape[out_axis] = a.shape[perms[out_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Enlarge every dimension so the element count no longer matches.
        for out_axis in range(rank):
            output_shape[out_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005583
5584 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor of a gather operation.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        values: rank-3 tensor exposing ``shape`` and ``dtype``; the rank is
            only asserted when ``error_name`` is not ``ErrorIf.WrongRank``.
        indices: rank-2 tensor whose first extent matches ``values``.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([values.dtype]))
    return ser.addOutput(output_shape, rng.choice(mismatched))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005609
5610 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor of a scatter operation.

    The output takes the shape of ``values_in``. The shape object is handed
    to ``ser.addOutput`` as is, without copying.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        values_in: rank-3 tensor exposing ``shape`` and ``dtype``; the rank
            is only asserted when ``error_name`` is not ``ErrorIf.WrongRank``.
        indices: rank-2 tensor of indices.
        input: rank-3 tensor of values to scatter.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    # NOTE: the set difference is kept as is; its iteration order is what
    # rng.choice samples from.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - set([values_in.dtype]))
    return ser.addOutput(values_in.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005638
5639 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor of a table lookup.

    The output has the input's shape. An INT16 input gives an INT32 output;
    any other input dtype gives INT8.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        input: input tensor exposing ``shape`` and ``dtype``; it must be
            INT16 or INT8 unless ``error_name`` is ``ErrorIf.WrongInputType``.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        expected_dtype = DType.INT32
    else:
        expected_dtype = DType.INT8

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(input.shape, expected_dtype)

    # Every dtype except the expected one is a valid wrong choice.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    candidates.remove(expected_dtype)
    return ser.addOutput(input.shape, rng.choice(candidates))
Eric Kunzee5e26762020-10-13 16:11:07 -07005658
5659 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor of a resize operation.

    Output height and width are computed per axis as
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``.
    ``mode`` and ``input_dtype`` are not read here.

    Args:
        serializer: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice`` and ``integers``.
        input: input tensor exposing ``shape``; indices 0..3 are read as
            batch, height, width and channels.
        mode: resize mode (unused here).
        scale: four values read as ``[y_n, y_d, x_n, x_d]``.
        offset: two values read as ``[y, x]``.
        border: two values read as ``[y, x]``.
        input_dtype: input data type (unused here).
        output_dtype: dtype passed straight to ``addOutput``.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``serializer.addOutput``.
    """

    def scaled_extent(in_extent, numer, denom, off, bord):
        # One-axis form of the OH / OW calculation.
        return ((in_extent - 1) * numer - off + bord) // denom + 1

    def nudged(extent, step):
        # Move by one whole step: up when that stays below the maximum
        # dimension, otherwise down.
        if extent + step >= gtu.MAX_RESIZE_DIMENSION:
            lowered = extent - step
            assert lowered > 0  # Should have been caught in agResize
            return lowered
        return extent + step

    # Calculate OH, OW
    y_numer, y_denom = scale[0], scale[1]
    x_numer, x_denom = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Force usable scale values so the calculation below still works.
        y_numer, y_denom = max(y_numer, 1), max(y_denom, 1)
        x_numer, x_denom = max(x_numer, 1), max(x_denom, 1)

    oh = scaled_extent(input.shape[1], y_numer, y_denom, offset[0], border[0])
    ow = scaled_extent(input.shape[2], x_numer, x_denom, offset[1], border[1])

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            ceiling = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, ceiling), min(ow, ceiling)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 alters the height, 2 the width, 3 both.
        change = rng.choice([1, 2, 3])
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if change in [1, 3]:
            oh = nudged(oh, y_denom)
        if change in [2, 3]:
            ow = nudged(ow, x_denom)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # The batch extent is reused as the last dimension in this case.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005737
5738 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output with the shape of ``val`` and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not read.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    converted = ser.addOutput(val.shape, out_dtype)
    return converted
Eric Kunzee5e26762020-10-13 16:11:07 -07005741
5742 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor of a transpose convolution.

    NOTE: for ``ErrorIf.ConvOutputShapeMismatch`` the caller's
    ``output_shape`` is modified in place (indices 1 and/or 2).

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice``.
        ifm: input tensor exposing ``dtype``.
        output_shape: mutable output shape supplied by the caller.
        accum_dtype: accumulator dtype, used as the output dtype in the
            normal case.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1 alters index 1, 2 alters index 2, 3 alters both.
        change = rng.choice(amounts)
        if change in [1, 3]:
            output_shape[1] += rng.choice(amounts)
        if change in [2, 3]:
            output_shape[2] += rng.choice(amounts)

    # Pick some potentially correct output dtype if input type is incorrect
    result_dtype = (
        DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    )

    if error_name == ErrorIf.WrongOutputType:
        # For an FP16 input both FP16 and FP32 are excluded from the wrong
        # choices.
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [result_dtype]
        result_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, result_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005767
5768 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors of a 2D FFT.

    Both outputs share one shape (a copy of the first input's shape) and
    one dtype (the input dtype), unless ``error_name`` asks otherwise.

    Args:
        serializer: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice`` and ``integers``.
        ifm1: first input tensor exposing ``shape`` and ``dtype``.
        ifm2: second input tensor; its dtype must match ``ifm1`` (asserted).
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        A list holding the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    result_shape = ifm1.shape.copy()
    result_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP32 is the only dtype excluded from the wrong choices.
        result_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Enlarge one of the two non-batch dimensions.
        target = rng.choice([1, 2])
        result_shape[target] += rng.integers(1, 10)

    return [
        serializer.addOutput(result_shape, result_dtype),
        serializer.addOutput(result_shape, result_dtype),
    ]
5798
5799 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors of a real-input 2D FFT.

    The outputs keep every input dimension except the last, which becomes
    ``last // 2 + 1``. Both outputs share the same shape and dtype.

    Args:
        serializer: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``choice`` and ``integers``.
        value: input tensor exposing ``shape`` and ``dtype``.
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        A list holding the two values returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    result_shape = list(in_shape[:-1])
    result_shape.append(in_shape[-1] // 2 + 1)

    result_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP32 is the only dtype excluded from the wrong choices.
        result_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Enlarge one of the two non-batch dimensions.
        target = rng.choice([1, 2])
        result_shape[target] += rng.integers(1, 10)

    return [
        serializer.addOutput(result_shape, result_dtype),
        serializer.addOutput(result_shape, result_dtype),
    ]
Won Jeon74342e52024-01-09 00:34:40 +00005823
5824 @staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of an element-wise op on two shape tensors.

    The output shape is a fresh list copy of ``a.shape`` and the dtype is
    ``DType.SHAPE`` unless ``error_name`` asks otherwise.

    Args:
        ser: object providing ``addOutput(shape, dtype)``.
        rng: random source providing ``integers`` and ``choice``.
        a: first input exposing ``shape`` and ``dtype``.
        b: second input; its dtype must match ``a`` (asserted).
        error_name: optional ``ErrorIf`` value selecting an invalid output.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    rank = len(a.shape)
    if error_name != ErrorIf.RankMismatch:
        assert rank == len(b.shape)
    assert a.dtype == b.dtype

    result_shape = [a.shape[dim] for dim in range(rank)]

    # The index is drawn on every call, even when unused, so the rng is
    # consumed the same way for all error cases.
    fuzz_idx = rng.integers(0, rank)
    if error_name == ErrorIf.DimensionMismatch:
        result_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        result_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        result_dtype = DType.SHAPE
    return ser.addOutput(result_shape, result_dtype)