blob: 28cf39206b4d324fbed5d5734ad470cd9c6df917 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Comment header written at the top of every generated C++ meta data file.
# The copyright year is taken from the date at import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# NOTE(review): presumably the scale/kernel/stride limits of the same
# 8K level - confirm against the specification.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator state from the parsed arguments.

    Reads output_dir, random_seed, generate_lib_path and
    tensor_fp_value_range from args, seeds the random number generator,
    builds the operator lists and creates the helper objects.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # The serializer is created per test by createSerializer()
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    # Work out floating point range
    def convertFPRange(rangeFP, maxFP):
        # Converts program arguments of max/-max to FP max
        def toValue(entry):
            if entry == "max":
                return maxFP
            if entry == "-max":
                return -maxFP
            if entry < 0:
                # Trim to minimum data type value
                return max(entry, -maxFP)
            if entry > 0:
                # Trim to maximum data type value
                return min(entry, maxFP)
            return entry

        return tuple(sorted(toValue(entry) for entry in rangeFP))

    self.random_float_range = {
        fp_type: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_type],
        )
        for fp_type in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create the test output directory and a new serializer for it.

    The test path (opName/testPath) is remembered in self.testPath and
    the new serializer is stored in self.ser.
    """
    self.testPath = os.path.join(opName, testPath)

    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        constMode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(outputDir, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the current serializer (self.ser)."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the files describing one test into its output directory.

    Writes <testName>.tosa (the flatbuffer binary) and desc.json.  When
    metaData is given it is stored under "meta" in desc.json; its
    "data_gen" entry (lazy data generation only) and "compliance" entry
    are additionally written out as C++ source files.  The descriptor is
    validated against the schema before any of the JSON based files are
    written.
    """
    testDir = Path(self.basePath) / self.testPath

    def writeCppMetaData(cppPath, description, declaration, payload):
        # Emit payload as a JSON raw string constant in a C++ file
        with cppPath.open("w") as cppFile:
            cppFile.write(TOSA_AUTOGENERATED_HEADER)
            cppFile.write(description)
            cppFile.write(declaration)
            json.dump(payload, cppFile)
            cppFile.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    flatbufferPath = testDir / f"{testName}.tosa"
    with flatbufferPath.open("wb") as binFile:
        binFile.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output datagen meta data as CPP data
            writeCppMetaData(
                testDir / f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{testDir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppMetaData(
                testDir / f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{testDir.stem} = R"(',
                metaData["compliance"],
            )

    # Write desc.json
    with (testDir / "desc.json").open("w") as descFile:
        json.dump(desc, descFile, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace self.rng with a freshly seeded generator.

    When no seed is given, self.random_seed + 1 is used.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the value range boundaries (low, high) for dtype.

    Floating point types return the entry of self.random_float_range
    unchanged.  For all other types the high boundary is excluded from
    the range unless high_inclusive is True.

    Raises:
        Exception: when dtype is not one of the handled types.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        # (dtype, (low, exclusive high)) for the fixed integer ranges
        fixed_ranges = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )
        for known_dtype, known_bounds in fixed_ranges:
            if dtype == known_dtype:
                bounds = known_bounds
                break
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= range < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a numpy array of random values of the given shape.

    Values are drawn from data_range (low, high) when supplied, otherwise
    from getDTypeRange(dtype).  BOOL tensors ignore the range.  The numpy
    type of the result depends on dtype: bool_, int8, uint8, int64 (INT48
    and SHAPE), float16, float32 (FP32 and BF16), otherwise int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        values = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(values)
        values_f32 = np.float32(values)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(values_f32))
        return values_f32

    if dtype == DType.INT8:
        to_numpy = np.int8
    elif dtype == DType.UINT8:
        to_numpy = np.uint8
    elif dtype in (DType.INT48, DType.SHAPE):
        to_numpy = np.int64
    else:
        # All other integer types
        to_numpy = np.int32
    return to_numpy(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder to the serializer per (shape, dtype) pair.

    Random data is created for each placeholder unless lazy data
    generation is enabled, in which case no data (None) is supplied.
    Returns the list of values returned by the serializer.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
228
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant to the serializer per (shape, dtype) pair.

    Random data is created for each constant unless lazy data generation
    is enabled, in which case no data (None) is supplied.
    Returns the list of values returned by the serializer.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))
    return tensors
241
def makeShape(self, rank):
    """Return a shape as a numpy int32 array.

    A target shape set via setTargetShape takes priority (rank is then
    ignored); otherwise rank random dimensions are drawn from
    args.tensor_shape_range (high value excluded).
    """
    forced_shape = self.targetted_shape
    if forced_shape:
        return np.int32(forced_shape)

    shape_range = self.args.tensor_shape_range
    dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
def setTargetShape(self, shape):
    """Set the shape that makeShape returns instead of a random one."""
    setattr(self, "targetted_shape", shape)
255
def randInt(self, low=0, high=256):
    """Return one random numpy int32 in the range low <= value < high."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
258
def getRandNumberDType(self, dtype):
    """Return a single random value within the range of dtype.

    The range comes from getDTypeRange(dtype).  Floating point types use
    a uniform distribution, BOOL picks between False and True, INT48 and
    SHAPE give a numpy int64 and every other type a numpy int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        value = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(value)
        value_f32 = np.float32(value)
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(value_f32)
        return value_f32

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 and SHAPE are the special size, needing 64 bits
    to_numpy = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
    return to_numpy(self.rng.integers(low, high, size=1))[0]
276
def shapeStr(self, shape):
    """Return the dimensions of shape joined by "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeStr(self, dtype):
    """Return the string name of dtype from gtu.DTYPE_ATTRIBUTES.

    A list or tuple of data types (at least two entries) gives the names
    of the first two joined by "x".

    Raises:
        Exception: when a data type is not in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(entry) for entry in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Raises:
        Exception: when dtype is not in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
def constrictBatchSize(self, shape):
    """Clamp shape[0] to args.max_batch_size, in place, and return shape.

    Nothing is changed when no maximum batch size is set or when explicit
    target shapes were requested.
    """
    # Limit the batch size unless an explicit target shape set
    if not self.args.max_batch_size or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
312
def makeDimension(self):
    """Return one random dimension drawn from args.tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
317
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Create the compliance meta data for the expected output tensor.

    Returns None for error tests and for data types that compliance does
    not support; otherwise a dict holding the compliance "mode" name, the
    output "data_type" and any mode specific information.
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # "mode" is filled in last, but is created first so that it stays the
    # first key of the dictionary
    meta = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    data_gen_type = argsDict["dg_type"]
    if data_gen_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        # "ksb" takes priority over "ks" when both are available
        meta["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict["ksb" if "ksb" in argsDict else "ks"]),
        }
    elif data_gen_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        meta["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        meta["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            meta["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    meta["mode"] = gtu.ComplianceMode(mode).name
    return meta
367
368 # Build Op functions
369 # Create the output tensor (calling OutputShaper as needed)
370 # Do final tweaks to attributes (if necessary for errorIf)
371 # Add Op into graph
372 # Return resulting tensor information or BuildInfo
373
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict.

    Attributes:
        resultTensor: the result tensor given to the constructor.
        complianceDict: the compliance dictionary given to the constructor.
    """

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700380
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an operator with a single input tensor to the graph.

    Returns a BuildInfo with the result tensor and its compliance meta
    data, or None when the ERROR_IF validation does not pass.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    wrong_output_type = error_name == ErrorIf.WrongOutputType
    if wrong_output_type and result_tensor.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    # Only NEGATE carries an attribute (built from the qinfo pair)
    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700433
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two input operator with a broadcast output shape to the graph.

    Returns a BuildInfo with the result tensor and its compliance meta
    data, or None when the ERROR_IF validation does not pass.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700475
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two input operator to the graph and return its result tensor.

    validator_fcns and error_name are accepted but not used here.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    input_names = [a.name, b.name]
    output_names = [result_tens.name]
    self.ser.addOperator(op["op"], input_names, output_names)
    return result_tens
480
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator to the graph.

    round is passed on to the operator attribute.  Returns the result
    tensor, or None when the ERROR_IF validation does not pass.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
518
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator to the graph.

    The shift attribute is read from args_dict["shift"].  Returns a
    BuildInfo with the result tensor and its compliance meta data, or
    None when the ERROR_IF validation does not pass.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Choose the output type at random from these candidates
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(self.rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700574
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table operator to the graph.

    table is passed on to the operator attribute.  Returns the result
    tensor, or None when the ERROR_IF validation does not pass.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
608
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a select operator (inputs are cond, a and b) to the graph.

    Returns a BuildInfo with the result tensor and its compliance meta
    data, or None when the ERROR_IF validation does not pass.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    result_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tensor.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two input comparison operator to the graph.

    Returns a BuildInfo with the result tensor and its compliance meta
    data, or None when the ERROR_IF validation does not pass.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validation_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not validation_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700704
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Add a single-input operator with an axis attribute (argmax).

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` supplies the axis.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(
        self.ser, self.rng, in_tensor, axis, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700748
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``stride``, ``pad`` and ``kernel`` are
            required, ``acc_type`` is optional.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points, or None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, in_tensor, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    wrong_type_case = error_name == ErrorIf.WrongInputType
    if wrong_type_case and in_tensor.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, in_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    # Fall back to zero for both zero points when no qinfo was supplied.
    if qinfo is None:
        qinfo = [0, 0]
    input_zp, output_zp = qinfo[0], qinfo[1]

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, input_zp, output_zp, accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100822
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors, unpacked as ``ifm, filter, bias``.
        args_dict: Test arguments; ``acc_type``, ``stride``, ``pad`` (four
            entries) and ``dilation`` are required.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700904
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors, unpacked as ``ifm, filter, bias``.
        args_dict: Test arguments; ``acc_type``, ``stride``, ``pad`` (six
            entries) and ``dilation`` are required.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        The result tensor itself (not a ``BuildInfo``), or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
981
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a transposed 2D convolution operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride attribute values.
        out_pad: Output padding; must have four entries.
        output_shape: Output shape attribute values.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        The result tensor itself (not a ``BuildInfo``), or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1049
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise 2D convolution operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors, unpacked as ``ifm, filter, bias``.
        args_dict: Test arguments; ``acc_type``, ``stride``, ``pad`` and
            ``dilation`` are required.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        The result tensor itself (not a ``BuildInfo``), or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1125
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors, unpacked as ``ifm, filter, bias``.
        args_dict: Test arguments; ``acc_type`` is required.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001181
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matrix-multiply operator to the serializer.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly two tensors, unpacked as ``a, b``.
        args_dict: Test arguments; ``acc_type`` is required.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element sequence of zero points. When None, both zero
            points default to 0 (same fallback as ``build_pool2d``).

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; without this fallback the
    # subscripting below raises TypeError. Applied after validation so the
    # validators still see the caller's original value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001231
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a single-input reduction operator with an axis attribute.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` supplies the axis.
            For a valid REDUCE_PRODUCT test the key ``"n"`` is written into
            this dict in place.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(
        self.ser, self.rng, in_tensor, axis, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        output_shape=result_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Number of products - needed for compliance
        args_dict["n"] = in_tensor.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001280
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a clamp operator with randomly drawn min/max bounds.

    Args:
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
            For ``ErrorIf.MaxSmallerMin`` the bounds are deliberately swapped.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (a,) = inputs

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    invert_bounds = error_name == ErrorIf.MaxSmallerMin
    if invert_bounds:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    smaller = min(bounds)
    larger = max(bounds)
    if invert_bounds:
        min_val, max_val = larger, smaller
    else:
        min_val, max_val = smaller, larger

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    is_float = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float:
        # Integer bounds go in the first pair; the float pair is zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float bounds go in the second pair; the integer pair is zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001346
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-ReLU operator with a random FP32 attribute value.

    No ERROR_IF validation is performed here; ``validator_fcns`` is unused.

    Args:
        op: Operator description; ``op["op"]`` is read.
        a: Input tensor.
        validator_fcns: Accepted but not used.
        error_name: Forwarded to the output shaper only.

    Returns:
        The result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    random_fp32 = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(random_fp32)

    input_names = [a.name]
    output_names = [result_tens.name]
    self.ser.addOperator(op["op"], input_names, output_names, leaky_attr)
    return result_tens
1355
1356 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add operator ``op["op"]`` with ``a`` as its only input.

    No attribute is attached and ``validator_fcns`` is not referenced.
    Returns the output tensor created by ``OutputShaper.unaryOp``.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    opcode = op["op"]
    self.ser.addOperator(opcode, [a.name], [output.name])
    return output
1362
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a one-input operator that carries no attribute.

    ``inputs`` must hold exactly one tensor.  ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 1
    operand = inputs[0]

    output = OutputShaper.unaryOp(self.ser, self.rng, operand, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Won Jeon78155c62023-06-10 00:20:04 +00001403
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over ``inputs``.

    For ``Op.CONCAT_SHAPE`` the axis is fixed at 0 and no attribute is
    attached; otherwise the axis is read from ``args_dict["axis"]`` and
    serialized through an Axis attribute.  ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance() is the idiomatic type check (pycodestyle E721).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        # Only CONCAT_SHAPE is expected here; it is serialized without an attribute.
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001462
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a pad operator for the single tensor in ``inputs``.

    Reads ``"pad"``, ``"pad_const_int"`` and ``"pad_const_fp"`` from
    ``args_dict``; the padding is flattened before being stored in the
    Pad attribute.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    output = OutputShaper.padOp(self.ser, self.rng, operand, padding, error_name)

    # The attribute is built before validation, as in the original ordering.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
        "input1": operand,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    meta = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001520
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a dim operator for the single tensor in ``inputs``.

    The axis comes from ``args_dict["axis"]`` and is stored in an Axis
    attribute.  ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` whose second field is ``None``
    (no compliance meta-data is produced here), or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    output = OutputShaper.dimOp(self.ser, self.rng, operand, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "input_dtype": operand.dtype,
        "output_shape": output.shape,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(output, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001566
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a reshape operator taking a data tensor and a shape tensor.

    ``inputs`` must hold exactly two tensors; both names are passed to the
    operator.  The output shape is derived from ``args_dict["new_shape"]``.
    ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 2
    operand = inputs[0]
    shape_tensor = inputs[1]
    new_shape = args_dict["new_shape"]
    output = OutputShaper.reshapeOp(
        self.ser, self.rng, operand, new_shape, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name, shape_tensor.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001610
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a reverse operator for the single tensor in ``inputs``.

    The axis comes from ``args_dict["axis"]`` and is stored in an Axis
    attribute.  ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` whose second field is ``None``
    (no compliance meta-data is produced here), or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    output = OutputShaper.unaryOp(self.ser, self.rng, operand, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(output, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001650
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a transpose of tensor ``a`` using permutation ``perms``.

    Returns the output tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    output = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation, as in the original ordering.
    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    in_names = [a.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": output.shape,
        "perms": perms,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
        "input1": a,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return output
1686
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a slice of tensor ``a`` described by ``start`` and ``size``.

    Returns the output tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    output = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [a.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": output.shape,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "start": start,
        "size": size,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
        "input1": a,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return output
1725
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a tile operator taking a data tensor and a multiples tensor.

    ``inputs`` must hold exactly two tensors; both names are passed to the
    operator.  The output shape is derived from ``args_dict["multiples"]``.
    ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 2
    operand = inputs[0]
    multiples_tensor = inputs[1]
    multiples_values = args_dict["multiples"]
    output = OutputShaper.tileOp(
        self.ser, self.rng, operand, multiples_values, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name, multiples_tensor.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
        "input1": operand,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001770
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a gather operator over a (values, indices) input pair.

    ``inputs`` must hold exactly two tensors, in that order.  ``qinfo`` is
    not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 2
    values, indices = inputs

    output = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [values.name, indices.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": values.shape,
        "output_shape": output.shape,
        "input_dtype": values.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001813
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a scatter operator over three input tensors.

    ``inputs`` must hold exactly three tensors, unpacked in order as
    values_in, indices and input.  ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs
    output = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [values_in.name, indices.name, input_tensor.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": values_in.shape,
        "output_shape": output.shape,
        "input_dtype": values_in.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001855
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a resize operator for tensor ``input``.

    ``scale``, ``offset``, ``border`` and ``mode`` are forwarded to the
    output shaper, to the error-if validation and to the Resize attribute.

    Returns the output tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    shaper_args = (
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )
    output = OutputShaper.resizeOp(self.ser, self.rng, *shaper_args)

    # Invalidate Input/Output list for error if checks.
    in_names = [input.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": output.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [output],
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return output
1917
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add a two-input, two-output operator for ``val`` and ``val2``.

    One output tensor is created per input via ``OutputShaper.unaryOp``.
    ``validator_fcns`` is not referenced.

    Returns only the first output tensor; the second is serialized but
    not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the operator code op["op"], as the other builders in this file
    # do, rather than the whole op description.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1925
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not referenced.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001929
1930 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a cast of the single tensor in ``inputs``.

    The output data type is read from ``args_dict["out_type"]``.
    ``qinfo`` is not referenced.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta-data, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    out_dtype = args_dict["out_type"]

    output = OutputShaper.typeConversionOp(
        self.ser, self.rng, operand, out_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [operand.name]
    out_names = [output.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    error_if_args = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": output.shape,
        "input_dtype": operand.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, output, error_name
    )
    return TosaTestGen.BuildInfo(output, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001974
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a RESCALE operator with randomly chosen zero points and scales.

    Args:
        op: operator definition dict; the "op" and "operands" entries are read.
        val: input tensor.
        out_dtype: requested output dtype.
        scale32: truthy to use the 32-bit scale path, otherwise the 16-bit one.
        double_round: forwarded to the validators and the rescale attribute.
        per_channel: truthy to generate one scale per element of the last
            dimension of ``val``; otherwise a single scale is generated.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs reports failure.

    Side effects:
        For scale32 tests without an error_name, the input placeholder file
        is rewritten when its values had to be clamped.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of scale values to generate
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # These error tests need a non-zero zero point
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # These error tests need a non-zero zero point
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        # Clamp (value - input_zp) to the shift limits, working in int64
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)

        # Check we can safely convert to the expected dtype. The comparison
        # must be made element-wise and then reduced; reducing first with
        # val_adj.all() would only compare a single bool against the limits.
        dtype_limits = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_limits.min) and np.all(
            val_adj <= dtype_limits.max
        ), "Adjusted rescale input values do not fit the placeholder dtype"

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2147
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the cond_if builders.

    Args:
        op: operator definition, passed to gtu.get_wrong_output_type when a
            non-boolean condition type is required.
        cond: condition value stored in the constant tensor.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The tensor returned by the serializer's addConst.
    """
    # Element type: boolean unless the wrong-type error is being tested
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2164
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Add a COND_IF whose THEN/ELSE blocks each output a constant tensor.

    The supplied then/else tensors are ignored except for the shape of
    ``then_tens``; each block is filled with a random INT32 constant.

    Args:
        op: operator definition dict; the "op" entry is read.
        then_tens: tensor whose shape is used for the block constants.
        else_tens: not used.
        cond: condition value for the condition tensor.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs reports failure.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Shape shared by the result and the well-formed block constants
    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    # Create an incorrect output shape for error_if tests
    if error_name in [then_mismatch, else_mismatch]:
        incorrect_shape = deepcopy(then_tens.shape)
        for axis, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[axis] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Build the actual then/else tensors inside their blocks; the block
    # targeted by a mismatch error gets the wrongly shaped constant instead
    for block_name, mismatch_error, block_values in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_values)
        self.ser.addOutputTensor(block_tens)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if checks_ok else None
2233
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Add a COND_IF whose THEN/ELSE blocks each apply a binary op to a and b.

    For FP32/BF16/FP16/INT32 inputs the blocks use ADD and SUB; for
    INT8/INT16 inputs they use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT.

    Args:
        op: operator definition dict; the "op" entry is read.
        a: first input tensor; its shape and dtype are used for the outputs.
        b: second input tensor.
        cond: condition value for the condition tensor.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs reports failure.

    Raises:
        AssertionError: if ``a.dtype`` is not one of the supported dtypes.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Copy of `a` with every dimension shifted, used to create the mismatch
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Raised explicitly so the guard is not stripped under `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The per-block operator gets its own name so that the `op` parameter
    # (the operator definition) is still what the validators receive below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2316
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Add a WHILE_LOOP with a COND block and a BODY block.

    The COND block compares the iteration counter against zero with GREATER;
    the BODY block adds ``a`` to the accumulator with ADD and subtracts one
    from the counter with SUB.

    Args:
        op: operator definition dict; the "op" entry is read.
        a: input tensor added to the accumulator in the body block.
        iter_val: initial value of the iteration counter placeholder.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The accumulator output tensor, or None when evValidateErrorIfs
        reports failure.
    """

    def perturb_in_place(shape):
        # Shift every dimension by a randomly chosen non-zero amount
        for axis in range(len(shape)):
            shape[axis] += self.rng.choice([-3, -2, 2, 3])

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape = acc.shape
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        perturb_in_place(incorrect_acc.shape)
        acc_out_shape = incorrect_acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    loop_inputs = [iter_tens.name, a.name, acc.name]
    loop_outputs = [iter_out.name, a_out.name, acc_out.name]
    self.ser.addOperator(op["op"], loop_inputs, loop_outputs, attr)
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Wrongly shaped copies of the counter and accumulator
        incorrect_iter = deepcopy(iter_tens)
        perturb_in_place(incorrect_iter.shape)
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        perturb_in_place(incorrect_acc.shape)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for block_input in cond_inputs:
        self.ser.addInputTensor(block_input)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for block_input in body_inputs:
        self.ser.addInputTensor(block_input)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_template, acc_template = incorrect_iter, incorrect_acc
    else:
        iter_template, acc_template = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(
        iter_template.shape, iter_template.dtype
    )
    acc_body_out = self.ser.addIntermediate(acc_template.shape, acc_template.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for block_output in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(block_output)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if checks_ok else None
2437
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Add an FFT2D operator taking two input tensors.

    Args:
        op: operator definition dict; the "op" and "operands" entries are read.
        val1: first input tensor; its shape and dtype are given to the validators.
        val2: second input tensor.
        inverse: forwarded to the validators and the FFT attribute.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The result tensors from OutputShaper.fft2dOp, or None when
        evValidateErrorIfs reports failure.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    result_names = [tens.name for tens in results]
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], result_names
    )

    validator_args = {
        "op": op,
        "inverse": inverse,
        "input1": val1,
        "input2": val2,
        "input_shape": val1.shape,
        "input_dtype": val1.dtype,
        "output_shape": result_shapes,
        "output_dtype": result_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2488
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Add an RFFT2D operator taking a single input tensor.

    Args:
        op: operator definition dict; the "op" and "operands" entries are read.
        val: input tensor.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.

    Returns:
        The result tensors from OutputShaper.rfft2dOp, or None when
        evValidateErrorIfs reports failure.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    result_names = [tens.name for tens in results]
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], result_names
    )

    validator_args = {
        "op": op,
        "input_shape": val.shape,
        "input_dtype": val.dtype,
        "output_shape": result_shapes,
        "output_dtype": result_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2534
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input shape operator to the serializer.

    Args:
        op: operator definition dict; the "op" and "operands" entries are read.
        inputs: list holding exactly two input tensors.
        args_dict: test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: name of the ERROR_IF being generated, or None.
        qinfo: accepted for signature compatibility; not used here.

    Returns:
        A TosaTestGen.BuildInfo of the result tensor and its compliance
        meta data, or None when evValidateErrorIfs reports failure.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.addShapeOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    validator_args = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_shape": lhs.shape,
        "input_dtype": lhs.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
2580
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters to use for an operator.

    Args:
        op: operator definition dict; the "rank" (min, max) and "types"
            entries are read.
        shapeFilter: requested shapes; an empty/None value becomes [None].
        rankFilter: requested ranks, or None for no restriction.
        dtypeFilter: requested dtypes, or None for no restriction.
        testType: "positive" or "negative".
        validator: for negative tests, a callable invoked as
            ``validator(check=False, op=op)`` whose "param_reqs" entry may
            override the rank, dtype and shape filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or for any other
        testType value.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        # Operator dtypes filtered by the requested dtypes; a list entry is
        # matched on its first element
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required, falling back to the cleaned filters
    required_rank = error_arguments["rank"]
    negRankFilter = cleanRankFilter if required_rank is None else required_rank

    required_dtype = error_arguments["dtype"]
    negDtypeFilter = cleanDtypeFilter if required_dtype is None else required_dtype

    required_shape = error_arguments["shape"]
    if required_shape is None:
        # Reduce number of shapes to keep test numbers small
        negShapeFilter = shapeFilter[:2]
    else:
        negShapeFilter = required_shape

    return {
        "shapeFilter": negShapeFilter,
        "rankFilter": negRankFilter,
        "dtypeFilter": negDtypeFilter,
    }
2661
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of tests to generate for a single operator.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        shapeFilter: list of target shapes; a ``None`` entry means no
            target shape. Defaults to ``[None]``.
        rankFilter: rank filter, passed through to ``create_filter_lists``.
        dtypeFilter: dtype filter, passed through to ``create_filter_lists``.
        testType: ``"positive"`` for normal tests or ``"negative"`` for
            ERROR_IF tests.

    Returns:
        A list of tuples
        ``(opName, testStr, dtype, error_name, shapeList, args)``;
        ``error_name`` is ``None`` for positive tests. An empty list is
        returned when ``create_filter_lists`` produces no filters.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        # Build the default per call rather than sharing one mutable list
        # object between all calls through the signature.
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Negative tests are generated once per ERROR_IF validator; everything
    # else runs a single pass with no validator.
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list is made of (str, args) tuples: the string
                    # representation and the build function arguments
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    # The name prefix does not depend on the arguments, so
                    # build it once per shape/type combination.
                    if testType == "negative":
                        namePrefix = (
                            f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                        )
                    else:
                        namePrefix = f"{opName}_{shapeStr}_{typeStr}"

                    for argStr, args in argList:
                        testStr = f"{namePrefix}_{argStr}" if argStr else namePrefix
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2768
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Generate the tensors for one test, build the operator and serialize it.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        testStr: test name, passed to ``createSerializer`` and used in
            messages.
        dtype_or_dtypeList: a single dtype, or a list holding one dtype per
            operand.
        error_name: ERROR_IF name passed on to the generator functions.
        shapeList: list of operand shapes.
        testArgs: either an arguments dictionary (new interface, must hold
            a ``"dg_type"`` key) or a list of positional build arguments
            (old interface).

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the old interface build call after the
            call details have been printed.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholderCount, constCount = op["operands"]
    num_operands = placeholderCount + constCount

    # For the concat ops the operand count is taken from the shape list
    # instead of the fixed operands entry.
    isConcat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeats = len(shapeList) if isConcat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeats

    if not isConcat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info comes from the optional "qgen" function of the op
    qgen = op.get("qgen")
    if qgen is None:
        qinfo = None
    else:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands and the test
    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(
            self, opName, dtypeList, shapeList, argsDict, error_name
        )
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        # Without validators qinfo (if any) goes last positionally; with
        # validators everything extra is passed by keyword.
        trailingArgs = ()
        keywordArgs = {}
        if error_if_validators is None:
            if qinfo is not None:
                trailingArgs = (qinfo,)
        else:
            keywordArgs["validator_fcns"] = error_if_validators
            keywordArgs["error_name"] = error_name
            if qinfo is not None:
                keywordArgs["qinfo"] = qinfo

        try:
            result = build_fcn(
                self, op, *tens, *testArgs, *trailingArgs, **keywordArgs
            )
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002894
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` convolution entries of ``TOSA_OP_LIST``.

    One concrete entry per kernel size is added for conv2d,
    depthwise_conv2d, transpose_conv2d and conv3d, named
    ``<op>_<k0>x<k1>[x<k2>]``. Each is a shallow copy of its template with
    ``"filter"`` set to the kernel and ``"template"`` set to False. Every
    entry still flagged as a template is removed afterwards.
    """
    opList = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes used for the dynamically created convolution ops
    if self.args.level8k:
        maxKernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels2d = [[1, maxKernel], [maxKernel, 2]]
        kernels3d = [[1, maxKernel, 1], [2, 2, maxKernel]]
    else:
        kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def addFromTemplate(prefix, kernel):
        # Register "<prefix>_<dims>" as a copy of "<prefix>_TEMPLATE"
        entry = opList[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        dims = "x".join(str(dim) for dim in kernel)
        opList["{}_{}".format(prefix, dims)] = entry

    for kernel in kernels2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addFromTemplate(prefix, kernel)

    for kernel in kernels3d:
        addFromTemplate("conv3d", kernel)

    # Delete any templates after having created any dynamic ops.
    # The names are collected first because keys must not be deleted from
    # a dictionary while it is being iterated.
    templateNames = [
        name for name, info in opList.items() if info.get("template", False)
    ]
    for name in templateNames:
        del opList[name]
2950
def initOpListDefaults(self):
    """Validate every ``TOSA_OP_LIST`` entry and fill in missing defaults.

    Each entry must provide a two-element ``"operands"`` value, a
    four-element ``"build_fcn"`` value, a ``"types"`` key and an ``"op"``
    key. Entries without a ``"rank"`` key get ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: for the first entry found with a missing or malformed
            required field.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Required fields: the unpacking checks both presence and length
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        try:
            _build, _tensorGen, _valuesGen, _argGen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        if "types" not in opInfo:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in opInfo:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Put in default rank range, if missing
        if "rank" not in opInfo:
            opInfo["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002992
# Tensor operator list
#  'op': TOSA operator (Op enum value)
#  'operands': tuple of (placeholder, const) operand counts
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
#  'types': array of datatypes to be tested
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (excludes INT4)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]

# Integer types (excludes INT4) followed by the floating-point types
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]

# floating-types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]

# Floating-point, integer (excludes INT4) and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL

TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to ops that do not set their own "rank" entry
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003044
3045 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003047 "argmax": {
3048 "op": Op.ARGMAX,
3049 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003050 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003051 "build_fcn": (
3052 build_argmax,
3053 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003054 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 TosaArgGen.agAxis,
3056 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003057 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003058 "error_if_validators": (
3059 TosaErrorValidator.evAxisSmallerZero,
3060 TosaErrorValidator.evAxisLargerRank,
3061 TosaErrorValidator.evArgmaxOutputRankMismatch,
3062 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3063 TosaErrorValidator.evWrongRank,
3064 TosaErrorValidator.evWrongInputType,
3065 TosaErrorValidator.evWrongOutputType,
3066 TosaErrorValidator.evWrongInputList,
3067 TosaErrorValidator.evWrongOutputList,
3068 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003069 "data_gen": {
3070 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3071 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003073 "avg_pool2d": {
3074 "op": Op.AVG_POOL2D,
3075 "operands": (1, 0),
3076 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 "build_fcn": (
3078 build_pool2d,
3079 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003080 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003081 TosaArgGen.agPooling,
3082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 "qgen": TosaQuantGen.qgUnary,
3084 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003085 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003086 "error_if_validators": (
3087 TosaErrorValidator.evKernelSmallerOne,
3088 TosaErrorValidator.evStrideSmallerOne,
3089 TosaErrorValidator.evPadSmallerZero,
3090 TosaErrorValidator.evWrongRank,
3091 TosaErrorValidator.evWrongInputType,
3092 TosaErrorValidator.evWrongOutputType,
3093 TosaErrorValidator.evWrongInputList,
3094 TosaErrorValidator.evWrongOutputList,
3095 TosaErrorValidator.evInputZeroPointNotZero,
3096 TosaErrorValidator.evOutputZeroPointNotZero,
3097 TosaErrorValidator.evPadLargerEqualKernel,
3098 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003099 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003100 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003101 "data_gen": {
3102 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3103 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003105 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003106 "conv2d_TEMPLATE": {
3107 "op": Op.CONV2D,
3108 "operands": (1, 2),
3109 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003110 "build_fcn": (
3111 build_conv2d,
3112 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003113 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003114 TosaArgGen.agConv,
3115 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003116 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003117 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003118 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3119 "error_if_validators": (
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 TosaErrorValidator.evInputZeroPointNotZero,
3125 TosaErrorValidator.evWeightZeroPointNotZero,
3126 TosaErrorValidator.evPadSmallerZero,
3127 TosaErrorValidator.evStrideSmallerOne,
3128 TosaErrorValidator.evDilationSmallerOne,
3129 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003130 TosaErrorValidator.evConvOutputShapeMismatch,
3131 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003132 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003133 "data_gen": {
3134 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3135 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003136 "template": True,
3137 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003138 # Templated operator. Filled in by createDynamicOpLists
3139 "conv3d_TEMPLATE": {
3140 "op": Op.CONV3D,
3141 "operands": (1, 2),
3142 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003143 "build_fcn": (
3144 build_conv3d,
3145 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003146 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003147 TosaArgGen.agConv,
3148 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003149 "qgen": TosaQuantGen.qgConv,
3150 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003151 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3152 "error_if_validators": (
3153 TosaErrorValidator.evWrongInputType,
3154 TosaErrorValidator.evWrongOutputType,
3155 TosaErrorValidator.evWrongInputList,
3156 TosaErrorValidator.evWrongOutputList,
3157 TosaErrorValidator.evInputZeroPointNotZero,
3158 TosaErrorValidator.evWeightZeroPointNotZero,
3159 TosaErrorValidator.evPadSmallerZero,
3160 TosaErrorValidator.evStrideSmallerOne,
3161 TosaErrorValidator.evDilationSmallerOne,
3162 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003163 TosaErrorValidator.evConvOutputShapeMismatch,
3164 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003165 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003166 "template": True,
3167 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003168 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003169 "depthwise_conv2d_TEMPLATE": {
3170 "op": Op.DEPTHWISE_CONV2D,
3171 "operands": (1, 2),
3172 "filter": [1, 1],
3173 "rank": (4, 4),
3174 "build_fcn": (
3175 build_depthwise_conv2d,
3176 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003177 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003178 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003179 ),
3180 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003181 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003182 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3183 "error_if_validators": (
3184 TosaErrorValidator.evWrongInputType,
3185 TosaErrorValidator.evWrongOutputType,
3186 TosaErrorValidator.evWrongInputList,
3187 TosaErrorValidator.evWrongOutputList,
3188 TosaErrorValidator.evInputZeroPointNotZero,
3189 TosaErrorValidator.evWeightZeroPointNotZero,
3190 TosaErrorValidator.evPadSmallerZero,
3191 TosaErrorValidator.evStrideSmallerOne,
3192 TosaErrorValidator.evDilationSmallerOne,
3193 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003194 TosaErrorValidator.evConvOutputShapeMismatch,
3195 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003196 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003197 "template": True,
3198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 "fully_connected": {
3200 "op": Op.FULLY_CONNECTED,
3201 "operands": (1, 2),
3202 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003203 "build_fcn": (
3204 build_fully_connected,
3205 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003206 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003207 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003209 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003210 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003211 "error_if_validators": (
3212 TosaErrorValidator.evInputZeroPointNotZero,
3213 TosaErrorValidator.evWeightZeroPointNotZero,
3214 TosaErrorValidator.evWrongRank,
3215 TosaErrorValidator.evWrongInputType,
3216 TosaErrorValidator.evWrongOutputType,
3217 TosaErrorValidator.evWrongInputList,
3218 TosaErrorValidator.evWrongOutputList,
3219 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003220 "data_gen": {
3221 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "matmul": {
3225 "op": Op.MATMUL,
3226 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003227 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_matmul,
3230 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003231 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003232 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "qgen": TosaQuantGen.qgMatmul,
3235 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evInputZeroPointNotZero,
3238 TosaErrorValidator.evWrongRank,
3239 TosaErrorValidator.evWrongInputType,
3240 TosaErrorValidator.evWrongOutputType,
3241 TosaErrorValidator.evWrongInputList,
3242 TosaErrorValidator.evWrongOutputList,
3243 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003244 "data_gen": {
3245 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 "max_pool2d": {
3249 "op": Op.MAX_POOL2D,
3250 "operands": (1, 0),
3251 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003253 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003254 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003255 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003256 TosaArgGen.agPooling,
3257 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003259 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003260 "error_if_validators": (
3261 TosaErrorValidator.evKernelSmallerOne,
3262 TosaErrorValidator.evStrideSmallerOne,
3263 TosaErrorValidator.evPadSmallerZero,
3264 TosaErrorValidator.evWrongRank,
3265 TosaErrorValidator.evWrongInputType,
3266 TosaErrorValidator.evWrongOutputType,
3267 TosaErrorValidator.evWrongInputList,
3268 TosaErrorValidator.evWrongOutputList,
3269 TosaErrorValidator.evPadLargerEqualKernel,
3270 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003271 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003272 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003273 "data_gen": {
3274 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003276 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003277 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003278 "transpose_conv2d_TEMPLATE": {
3279 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003280 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003281 "rank": (4, 4),
3282 "build_fcn": (
3283 build_transpose_conv2d,
3284 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003286 TosaArgGen.agTransposeConv2D,
3287 ),
3288 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003289 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003290 "invalid_test_validators": (
3291 TosaInvalidValidator.ivHeightWidthInvalid,
3292 TosaInvalidValidator.ivNonPositiveOutputShape,
3293 ),
3294 "error_if_validators": (
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 TosaErrorValidator.evInputZeroPointNotZero,
3300 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003301 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003302 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003303 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003304 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003305 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003306 "template": True,
3307 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003308 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003309 "clamp": {
3310 "op": Op.CLAMP,
3311 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 "build_fcn": (
3313 build_clamp,
3314 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003315 TosaTensorValuesGen.tvgLazyGenDefault,
3316 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003317 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003318 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 "error_if_validators": (
3320 TosaErrorValidator.evMaxSmallerMin,
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003326 "data_gen": {
3327 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3328 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003329 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003330 "sigmoid": {
3331 "op": Op.SIGMOID,
3332 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003333 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003334 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003336 TosaTensorValuesGen.tvgLazyGenDefault,
3337 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003338 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003339 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003340 "error_if_validators": (
3341 TosaErrorValidator.evWrongInputType,
3342 TosaErrorValidator.evWrongOutputType,
3343 TosaErrorValidator.evWrongInputList,
3344 TosaErrorValidator.evWrongOutputList,
3345 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003346 "data_gen": {
3347 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3348 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003349 },
3350 "tanh": {
3351 "op": Op.TANH,
3352 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003353 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003354 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003356 TosaTensorValuesGen.tvgLazyGenDefault,
3357 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003359 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003360 "error_if_validators": (
3361 TosaErrorValidator.evWrongInputType,
3362 TosaErrorValidator.evWrongOutputType,
3363 TosaErrorValidator.evWrongInputList,
3364 TosaErrorValidator.evWrongOutputList,
3365 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003366 "data_gen": {
3367 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3368 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003369 "compliance": {
3370 "abs_error_lower_bound": 0.5,
3371 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003372 },
Won Jeon78155c62023-06-10 00:20:04 +00003373 "erf": {
3374 "op": Op.ERF,
3375 "operands": (1, 0),
3376 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003377 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003378 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003379 TosaTensorValuesGen.tvgLazyGenDefault,
3380 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003381 ),
3382 "types": TYPE_FP,
3383 "error_if_validators": (
3384 TosaErrorValidator.evWrongInputType,
3385 TosaErrorValidator.evWrongOutputType,
3386 TosaErrorValidator.evWrongInputList,
3387 TosaErrorValidator.evWrongOutputList,
3388 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003389 "data_gen": {
3390 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3391 },
3392 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 # Elementwise Binary Operators
3395 "add": {
3396 "op": Op.ADD,
3397 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003398 "build_fcn": (
3399 build_binary_broadcast,
3400 TosaTensorGen.tgBroadcastFuzz,
3401 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003402 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003405 "error_if_validators": (
3406 TosaErrorValidator.evRankMismatch,
3407 TosaErrorValidator.evWrongInputType,
3408 TosaErrorValidator.evWrongOutputType,
3409 TosaErrorValidator.evWrongInputList,
3410 TosaErrorValidator.evWrongOutputList,
3411 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003412 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003413 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003414 "data_gen": {
3415 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3416 },
3417 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "arithmetic_right_shift": {
3420 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3421 "operands": (2, 0),
3422 "build_fcn": (
3423 build_arithmetic_right_shift,
3424 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003425 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 TosaArgGen.agArithmeticRightShift,
3427 ),
3428 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evRankMismatch,
3431 TosaErrorValidator.evWrongInputType,
3432 TosaErrorValidator.evWrongOutputType,
3433 TosaErrorValidator.evWrongInputList,
3434 TosaErrorValidator.evWrongOutputList,
3435 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003436 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003439 "bitwise_and": {
3440 "op": Op.BITWISE_AND,
3441 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 "build_fcn": (
3443 build_binary_broadcast,
3444 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003445 TosaTensorValuesGen.tvgLazyGenDefault,
3446 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 "error_if_validators": (
3450 TosaErrorValidator.evRankMismatch,
3451 TosaErrorValidator.evWrongInputType,
3452 TosaErrorValidator.evWrongOutputType,
3453 TosaErrorValidator.evWrongInputList,
3454 TosaErrorValidator.evWrongOutputList,
3455 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003456 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 "bitwise_or": {
3460 "op": Op.BITWISE_OR,
3461 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 "build_fcn": (
3463 build_binary_broadcast,
3464 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003465 TosaTensorValuesGen.tvgLazyGenDefault,
3466 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003468 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 "error_if_validators": (
3470 TosaErrorValidator.evRankMismatch,
3471 TosaErrorValidator.evWrongInputType,
3472 TosaErrorValidator.evWrongOutputType,
3473 TosaErrorValidator.evWrongInputList,
3474 TosaErrorValidator.evWrongOutputList,
3475 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003476 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003477 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "bitwise_xor": {
3480 "op": Op.BITWISE_XOR,
3481 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003482 "build_fcn": (
3483 build_binary_broadcast,
3484 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003485 TosaTensorValuesGen.tvgLazyGenDefault,
3486 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 "error_if_validators": (
3490 TosaErrorValidator.evRankMismatch,
3491 TosaErrorValidator.evWrongInputType,
3492 TosaErrorValidator.evWrongOutputType,
3493 TosaErrorValidator.evWrongInputList,
3494 TosaErrorValidator.evWrongOutputList,
3495 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003496 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003499 "intdiv": {
3500 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003501 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003502 "build_fcn": (
3503 build_binary_broadcast,
3504 TosaTensorGen.tgBroadcastFuzz,
3505 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003506 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003508 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 "error_if_validators": (
3510 TosaErrorValidator.evRankMismatch,
3511 TosaErrorValidator.evWrongInputType,
3512 TosaErrorValidator.evWrongOutputType,
3513 TosaErrorValidator.evWrongInputList,
3514 TosaErrorValidator.evWrongOutputList,
3515 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003516 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003517 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003519 "logical_and": {
3520 "op": Op.LOGICAL_AND,
3521 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 "build_fcn": (
3523 build_binary_broadcast,
3524 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003525 TosaTensorValuesGen.tvgLazyGenDefault,
3526 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 "error_if_validators": (
3530 TosaErrorValidator.evRankMismatch,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003536 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003537 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003539 "logical_left_shift": {
3540 "op": Op.LOGICAL_LEFT_SHIFT,
3541 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 "build_fcn": (
3543 build_binary_broadcast,
3544 TosaTensorGen.tgBroadcastFuzz,
3545 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003546 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003549 "error_if_validators": (
3550 TosaErrorValidator.evRankMismatch,
3551 TosaErrorValidator.evWrongInputType,
3552 TosaErrorValidator.evWrongOutputType,
3553 TosaErrorValidator.evWrongInputList,
3554 TosaErrorValidator.evWrongOutputList,
3555 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003556 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "logical_right_shift": {
3560 "op": Op.LOGICAL_RIGHT_SHIFT,
3561 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 "build_fcn": (
3563 build_binary_broadcast,
3564 TosaTensorGen.tgBroadcastFuzz,
3565 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003566 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 "error_if_validators": (
3570 TosaErrorValidator.evRankMismatch,
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003576 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003577 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003578 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003579 "logical_or": {
3580 "op": Op.LOGICAL_OR,
3581 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003582 "build_fcn": (
3583 build_binary_broadcast,
3584 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003585 TosaTensorValuesGen.tvgLazyGenDefault,
3586 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003587 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003588 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003589 "error_if_validators": (
3590 TosaErrorValidator.evRankMismatch,
3591 TosaErrorValidator.evWrongInputType,
3592 TosaErrorValidator.evWrongOutputType,
3593 TosaErrorValidator.evWrongInputList,
3594 TosaErrorValidator.evWrongOutputList,
3595 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003596 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003597 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 "logical_xor": {
3600 "op": Op.LOGICAL_XOR,
3601 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003602 "build_fcn": (
3603 build_binary_broadcast,
3604 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003605 TosaTensorValuesGen.tvgLazyGenDefault,
3606 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003608 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003609 "error_if_validators": (
3610 TosaErrorValidator.evRankMismatch,
3611 TosaErrorValidator.evWrongInputType,
3612 TosaErrorValidator.evWrongOutputType,
3613 TosaErrorValidator.evWrongInputList,
3614 TosaErrorValidator.evWrongOutputList,
3615 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003616 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003617 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 "maximum": {
3620 "op": Op.MAXIMUM,
3621 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 "build_fcn": (
3623 build_binary_broadcast,
3624 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003625 TosaTensorValuesGen.tvgLazyGenDefault,
3626 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003627 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 "error_if_validators": (
3630 TosaErrorValidator.evRankMismatch,
3631 TosaErrorValidator.evWrongInputType,
3632 TosaErrorValidator.evWrongOutputType,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003636 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003638 "data_gen": {
3639 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3640 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003641 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 "minimum": {
3643 "op": Op.MINIMUM,
3644 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003645 "build_fcn": (
3646 build_binary_broadcast,
3647 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003648 TosaTensorValuesGen.tvgLazyGenDefault,
3649 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003652 "error_if_validators": (
3653 TosaErrorValidator.evRankMismatch,
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongInputList,
3657 TosaErrorValidator.evWrongOutputList,
3658 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003659 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003660 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003661 "data_gen": {
3662 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3663 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 "mul": {
3666 "op": Op.MUL,
3667 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003668 "build_fcn": (
3669 build_mul,
3670 TosaTensorGen.tgBroadcastFuzz,
3671 TosaTensorValuesGen.tvgMul,
3672 TosaArgGen.agMul,
3673 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003674 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003675 "error_if_validators": (
3676 TosaErrorValidator.evWrongInputType,
3677 TosaErrorValidator.evWrongOutputType,
3678 TosaErrorValidator.evWrongInputList,
3679 TosaErrorValidator.evWrongOutputList,
3680 TosaErrorValidator.evRankMismatch,
3681 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003682 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003683 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003684 "data_gen": {
3685 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3686 },
3687 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 "pow": {
3690 "op": Op.POW,
3691 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 "build_fcn": (
3693 build_binary_broadcast,
3694 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003695 TosaTensorValuesGen.tvgPow,
3696 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003699 "error_if_validators": (
3700 TosaErrorValidator.evRankMismatch,
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003706 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003708 "data_gen": {
3709 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003712 "sub": {
3713 "op": Op.SUB,
3714 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 "build_fcn": (
3716 build_binary_broadcast,
3717 TosaTensorGen.tgBroadcastFuzz,
3718 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003719 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003720 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 "error_if_validators": (
3723 TosaErrorValidator.evRankMismatch,
3724 TosaErrorValidator.evWrongInputType,
3725 TosaErrorValidator.evWrongOutputType,
3726 TosaErrorValidator.evWrongInputList,
3727 TosaErrorValidator.evWrongOutputList,
3728 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003729 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003730 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003731 "data_gen": {
3732 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3733 },
3734 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003736 "table": {
3737 "op": Op.TABLE,
3738 # Use the automatic generation functions to create the input array
3739 # but create the table tensor in the build function, as it may be
3740 # a different type from the input
3741 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 "build_fcn": (
3743 build_table,
3744 TosaTensorGen.tgBasic,
3745 TosaTensorValuesGen.tvgDefault,
3746 TosaArgGen.agTable,
3747 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003748 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 "error_if_validators": (
3750 TosaErrorValidator.evWrongInputType,
3751 TosaErrorValidator.evWrongOutputType,
3752 TosaErrorValidator.evWrongInputList,
3753 TosaErrorValidator.evWrongOutputList,
3754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003756 # Elementwise Unary operators
3757 "abs": {
3758 "op": Op.ABS,
3759 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003760 "build_fcn": (
3761 build_unary,
3762 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003763 TosaTensorValuesGen.tvgLazyGenDefault,
3764 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003765 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 "error_if_validators": (
3768 TosaErrorValidator.evWrongInputType,
3769 TosaErrorValidator.evWrongOutputType,
3770 TosaErrorValidator.evWrongInputList,
3771 TosaErrorValidator.evWrongOutputList,
3772 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003773 "data_gen": {
3774 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3775 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003776 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003777 "bitwise_not": {
3778 "op": Op.BITWISE_NOT,
3779 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003780 "build_fcn": (
3781 build_unary,
3782 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003783 TosaTensorValuesGen.tvgLazyGenDefault,
3784 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003787 "error_if_validators": (
3788 TosaErrorValidator.evWrongInputType,
3789 TosaErrorValidator.evWrongOutputType,
3790 TosaErrorValidator.evWrongInputList,
3791 TosaErrorValidator.evWrongOutputList,
3792 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003794 "ceil": {
3795 "op": Op.CEIL,
3796 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 "build_fcn": (
3798 build_unary,
3799 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003800 TosaTensorValuesGen.tvgLazyGenDefault,
3801 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 "error_if_validators": (
3805 TosaErrorValidator.evWrongInputType,
3806 TosaErrorValidator.evWrongOutputType,
3807 TosaErrorValidator.evWrongInputList,
3808 TosaErrorValidator.evWrongOutputList,
3809 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003810 "data_gen": {
3811 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3812 },
3813 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003814 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 "clz": {
3816 "op": Op.CLZ,
3817 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003818 "build_fcn": (
3819 build_unary,
3820 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003821 TosaTensorValuesGen.tvgLazyGenDefault,
3822 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003823 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 "error_if_validators": (
3826 TosaErrorValidator.evWrongInputType,
3827 TosaErrorValidator.evWrongOutputType,
3828 TosaErrorValidator.evWrongInputList,
3829 TosaErrorValidator.evWrongOutputList,
3830 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003831 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 "exp": {
3833 "op": Op.EXP,
3834 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003835 "build_fcn": (
3836 build_unary,
3837 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003838 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003839 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003842 "error_if_validators": (
3843 TosaErrorValidator.evWrongInputType,
3844 TosaErrorValidator.evWrongOutputType,
3845 TosaErrorValidator.evWrongInputList,
3846 TosaErrorValidator.evWrongOutputList,
3847 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003848 "data_gen": {
3849 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 "floor": {
3853 "op": Op.FLOOR,
3854 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003855 "build_fcn": (
3856 build_unary,
3857 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003858 TosaTensorValuesGen.tvgLazyGenDefault,
3859 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003860 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 "error_if_validators": (
3863 TosaErrorValidator.evWrongInputType,
3864 TosaErrorValidator.evWrongOutputType,
3865 TosaErrorValidator.evWrongInputList,
3866 TosaErrorValidator.evWrongOutputList,
3867 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003868 "data_gen": {
3869 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3870 },
3871 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "log": {
3874 "op": Op.LOG,
3875 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003876 "build_fcn": (
3877 build_unary,
3878 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003879 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003881 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003882 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003883 "error_if_validators": (
3884 TosaErrorValidator.evWrongInputType,
3885 TosaErrorValidator.evWrongOutputType,
3886 TosaErrorValidator.evWrongInputList,
3887 TosaErrorValidator.evWrongOutputList,
3888 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003889 "data_gen": {
3890 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3891 },
3892 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003894 "logical_not": {
3895 "op": Op.LOGICAL_NOT,
3896 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 "build_fcn": (
3898 build_unary,
3899 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003900 TosaTensorValuesGen.tvgLazyGenDefault,
3901 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 "error_if_validators": (
3905 TosaErrorValidator.evWrongInputType,
3906 TosaErrorValidator.evWrongOutputType,
3907 TosaErrorValidator.evWrongInputList,
3908 TosaErrorValidator.evWrongOutputList,
3909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 "negate": {
3912 "op": Op.NEGATE,
3913 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 "build_fcn": (
3915 build_unary,
3916 TosaTensorGen.tgBasic,
3917 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003918 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 "qgen": TosaQuantGen.qgUnary,
3921 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003922 "error_if_validators": (
3923 TosaErrorValidator.evInputZeroPointNotZero,
3924 TosaErrorValidator.evOutputZeroPointNotZero,
3925 TosaErrorValidator.evWrongInputType,
3926 TosaErrorValidator.evWrongOutputType,
3927 TosaErrorValidator.evWrongInputList,
3928 TosaErrorValidator.evWrongOutputList,
3929 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003930 "data_gen": {
3931 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3932 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003933 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003934 "reciprocal": {
3935 "op": Op.RECIPROCAL,
3936 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003937 "build_fcn": (
3938 build_unary,
3939 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003940 TosaTensorValuesGen.tvgLazyGenDefault,
3941 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003942 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003944 "error_if_validators": (
3945 TosaErrorValidator.evWrongInputType,
3946 TosaErrorValidator.evWrongOutputType,
3947 TosaErrorValidator.evWrongInputList,
3948 TosaErrorValidator.evWrongOutputList,
3949 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003950 "data_gen": {
3951 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3952 },
3953 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003954 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003955 "rsqrt": {
3956 "op": Op.RSQRT,
3957 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003958 "build_fcn": (
3959 build_unary,
3960 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003961 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003962 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003963 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003965 "error_if_validators": (
3966 TosaErrorValidator.evWrongInputType,
3967 TosaErrorValidator.evWrongOutputType,
3968 TosaErrorValidator.evWrongInputList,
3969 TosaErrorValidator.evWrongOutputList,
3970 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003971 "data_gen": {
3972 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3973 },
3974 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 # Elementwise Ternary operators
3977 "select": {
3978 "op": Op.SELECT,
3979 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_select,
3982 TosaTensorGen.tgBroadcastFuzz,
3983 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003984 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evRankMismatch,
3989 TosaErrorValidator.evWrongInputType,
3990 TosaErrorValidator.evWrongOutputType,
3991 TosaErrorValidator.evWrongInputList,
3992 TosaErrorValidator.evWrongOutputList,
3993 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003994 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003996 "data_gen": {
3997 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 # Comparison operators
4001 "equal": {
4002 "op": Op.EQUAL,
4003 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004004 "build_fcn": (
4005 build_comparison,
4006 TosaTensorGen.tgBroadcastFuzz,
4007 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004008 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004011 "error_if_validators": (
4012 TosaErrorValidator.evRankMismatch,
4013 TosaErrorValidator.evWrongInputType,
4014 TosaErrorValidator.evWrongOutputType,
4015 TosaErrorValidator.evWrongInputList,
4016 TosaErrorValidator.evWrongOutputList,
4017 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004018 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004019 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004020 "data_gen": {
4021 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4022 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004023 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 "greater_equal": {
4025 "op": Op.GREATER_EQUAL,
4026 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004027 "build_fcn": (
4028 build_comparison,
4029 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004030 TosaTensorValuesGen.tvgLazyGenDefault,
4031 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004032 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004034 "error_if_validators": (
4035 TosaErrorValidator.evRankMismatch,
4036 TosaErrorValidator.evWrongInputType,
4037 TosaErrorValidator.evWrongOutputType,
4038 TosaErrorValidator.evWrongInputList,
4039 TosaErrorValidator.evWrongOutputList,
4040 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004041 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004042 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004043 "data_gen": {
4044 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4045 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004046 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004047 "greater": {
4048 "op": Op.GREATER,
4049 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004050 "build_fcn": (
4051 build_comparison,
4052 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004053 TosaTensorValuesGen.tvgLazyGenDefault,
4054 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004055 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004057 "error_if_validators": (
4058 TosaErrorValidator.evRankMismatch,
4059 TosaErrorValidator.evWrongInputType,
4060 TosaErrorValidator.evWrongOutputType,
4061 TosaErrorValidator.evWrongInputList,
4062 TosaErrorValidator.evWrongOutputList,
4063 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004064 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004065 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004066 "data_gen": {
4067 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4068 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004070 # Reduction operators
4071 "reduce_all": {
4072 "op": Op.REDUCE_ALL,
4073 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004074 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004075 "build_fcn": (
4076 build_reduce,
4077 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004078 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004079 TosaArgGen.agAxis,
4080 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004082 "error_if_validators": (
4083 TosaErrorValidator.evAxisLargerRank,
4084 TosaErrorValidator.evAxisSmallerZero,
4085 TosaErrorValidator.evShapeOfAxisNotOne,
4086 TosaErrorValidator.evWrongInputType,
4087 TosaErrorValidator.evWrongOutputType,
4088 TosaErrorValidator.evWrongRank,
4089 TosaErrorValidator.evWrongInputList,
4090 TosaErrorValidator.evWrongOutputList,
4091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 "reduce_any": {
4094 "op": Op.REDUCE_ANY,
4095 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004096 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004097 "build_fcn": (
4098 build_reduce,
4099 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004100 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004101 TosaArgGen.agAxis,
4102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004104 "error_if_validators": (
4105 TosaErrorValidator.evAxisLargerRank,
4106 TosaErrorValidator.evAxisSmallerZero,
4107 TosaErrorValidator.evShapeOfAxisNotOne,
4108 TosaErrorValidator.evWrongInputType,
4109 TosaErrorValidator.evWrongOutputType,
4110 TosaErrorValidator.evWrongRank,
4111 TosaErrorValidator.evWrongInputList,
4112 TosaErrorValidator.evWrongOutputList,
4113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004114 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004115 "reduce_max": {
4116 "op": Op.REDUCE_MAX,
4117 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004118 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004119 "build_fcn": (
4120 build_reduce,
4121 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004122 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004123 TosaArgGen.agAxis,
4124 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004125 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004126 "error_if_validators": (
4127 TosaErrorValidator.evAxisLargerRank,
4128 TosaErrorValidator.evAxisSmallerZero,
4129 TosaErrorValidator.evShapeOfAxisNotOne,
4130 TosaErrorValidator.evWrongInputType,
4131 TosaErrorValidator.evWrongOutputType,
4132 TosaErrorValidator.evWrongRank,
4133 TosaErrorValidator.evWrongInputList,
4134 TosaErrorValidator.evWrongOutputList,
4135 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004136 "data_gen": {
4137 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4138 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004139 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004140 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004141 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004142 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004143 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004144 "build_fcn": (
4145 build_reduce,
4146 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004147 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004148 TosaArgGen.agAxis,
4149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 "error_if_validators": (
4152 TosaErrorValidator.evAxisLargerRank,
4153 TosaErrorValidator.evAxisSmallerZero,
4154 TosaErrorValidator.evShapeOfAxisNotOne,
4155 TosaErrorValidator.evWrongInputType,
4156 TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongRank,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
4160 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004161 "data_gen": {
4162 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "reduce_product": {
4166 "op": Op.REDUCE_PRODUCT,
4167 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004168 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004169 "build_fcn": (
4170 build_reduce,
4171 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004172 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 TosaArgGen.agAxis,
4174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004176 "error_if_validators": (
4177 TosaErrorValidator.evAxisLargerRank,
4178 TosaErrorValidator.evAxisSmallerZero,
4179 TosaErrorValidator.evShapeOfAxisNotOne,
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongRank,
4183 TosaErrorValidator.evWrongInputList,
4184 TosaErrorValidator.evWrongOutputList,
4185 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004186 "data_gen": {
4187 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 "reduce_sum": {
4191 "op": Op.REDUCE_SUM,
4192 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004193 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_reduce,
4196 TosaTensorGen.tgBasic,
4197 TosaTensorValuesGen.tvgReduceSum,
4198 TosaArgGen.agAxis,
4199 ),
James Ward24dbc422022-10-19 12:20:31 +01004200 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evAxisLargerRank,
4203 TosaErrorValidator.evAxisSmallerZero,
4204 TosaErrorValidator.evShapeOfAxisNotOne,
4205 TosaErrorValidator.evWrongInputType,
4206 TosaErrorValidator.evWrongOutputType,
4207 TosaErrorValidator.evWrongRank,
4208 TosaErrorValidator.evWrongInputList,
4209 TosaErrorValidator.evWrongOutputList,
4210 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004211 "data_gen": {
4212 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4213 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004214 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004215 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004216 "concat": {
4217 "op": Op.CONCAT,
4218 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004219 "build_fcn": (
4220 build_concat,
4221 TosaTensorGen.tgConcat,
4222 TosaTensorValuesGen.tvgConcat,
4223 TosaArgGen.agAxis,
4224 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004225 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004226 "error_if_validators": (
4227 TosaErrorValidator.evAxisLargerRank,
4228 TosaErrorValidator.evAxisSmallerZero,
4229 TosaErrorValidator.evConcatInputRankMismatch,
4230 TosaErrorValidator.evConcatShapeSumMismatch,
4231 TosaErrorValidator.evConcatInputDimMismatch,
4232 TosaErrorValidator.evWrongInputType,
4233 TosaErrorValidator.evWrongOutputType,
4234 TosaErrorValidator.evWrongOutputList,
4235 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004236 "data_gen": {
4237 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4238 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004239 },
4240 "pad": {
4241 "op": Op.PAD,
4242 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004243 "build_fcn": (
4244 build_pad,
4245 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004246 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004247 TosaArgGen.agPad,
4248 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004249 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004250 "error_if_validators": (
4251 TosaErrorValidator.evWrongInputType,
4252 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004253 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004254 TosaErrorValidator.evWrongOutputType,
4255 TosaErrorValidator.evWrongInputList,
4256 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004257 TosaErrorValidator.evRankMismatch,
4258 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004259 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004260 "data_gen": {
4261 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4262 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004263 },
Won Jeona21b2e82023-08-10 10:33:01 +00004264 "dim": {
4265 "op": Op.DIM,
4266 "operands": (1, 0),
4267 "build_fcn": (
4268 build_dim,
4269 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004270 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004271 TosaArgGen.agAxis,
4272 ),
4273 "types": TYPE_FIB,
4274 "error_if_validators": (
4275 TosaErrorValidator.evAxisLargerRank,
4276 TosaErrorValidator.evAxisSmallerZero,
4277 TosaErrorValidator.evWrongInputType,
4278 TosaErrorValidator.evWrongInputList,
4279 TosaErrorValidator.evWrongOutputList,
4280 TosaErrorValidator.evWrongRank,
4281 ),
4282 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004283 "reshape": {
4284 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004285 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004286 "build_fcn": (
4287 build_reshape,
4288 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004289 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004290 TosaArgGen.agReshape,
4291 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004292 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004293 "error_if_validators": (
4294 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4295 TosaErrorValidator.evWrongInputType,
4296 TosaErrorValidator.evWrongOutputType,
4297 TosaErrorValidator.evWrongInputList,
4298 TosaErrorValidator.evWrongOutputList,
4299 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004300 "data_gen": {
4301 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4302 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004303 },
4304 "reverse": {
4305 "op": Op.REVERSE,
4306 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004307 "build_fcn": (
4308 build_reverse,
4309 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004310 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 TosaArgGen.agAxis,
4312 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004313 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004314 "error_if_validators": (
4315 TosaErrorValidator.evAxisSmallerZero,
4316 TosaErrorValidator.evAxisLargerRank,
4317 TosaErrorValidator.evWrongInputType,
4318 TosaErrorValidator.evWrongOutputType,
4319 TosaErrorValidator.evWrongInputList,
4320 TosaErrorValidator.evWrongOutputList,
4321 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004322 },
4323 "slice": {
4324 "op": Op.SLICE,
4325 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004326 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004327 "build_fcn": (
4328 build_slice,
4329 TosaTensorGen.tgBasic,
4330 TosaTensorValuesGen.tvgDefault,
4331 TosaArgGen.agSlice,
4332 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004333 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004334 "error_if_validators": (
4335 TosaErrorValidator.evStartSmallerZero,
4336 TosaErrorValidator.evSizeSmallerEqualZero,
4337 TosaErrorValidator.evStartSizeOutsideBounds,
4338 TosaErrorValidator.evSizeOutputShapeMismatch,
4339 TosaErrorValidator.evInputSizeStartLengthMismatch,
4340 TosaErrorValidator.evWrongRank,
4341 TosaErrorValidator.evWrongInputType,
4342 TosaErrorValidator.evWrongOutputType,
4343 TosaErrorValidator.evWrongInputList,
4344 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004345 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004346 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004347 },
4348 "tile": {
4349 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004350 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004351 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004352 "build_fcn": (
4353 build_tile,
4354 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004355 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004356 TosaArgGen.agTile,
4357 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004359 "error_if_validators": (
4360 TosaErrorValidator.evWrongInputType,
4361 TosaErrorValidator.evWrongOutputType,
4362 TosaErrorValidator.evWrongInputList,
4363 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004364 TosaErrorValidator.evRankMismatch,
4365 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004366 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004367 "data_gen": {
4368 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4369 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004370 },
4371 "transpose": {
4372 "op": Op.TRANSPOSE,
4373 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004374 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004375 "build_fcn": (
4376 build_transpose,
4377 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004378 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004379 TosaArgGen.agTranspose,
4380 ),
4381 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004382 "error_if_validators": (
4383 TosaErrorValidator.evIndexOutsideBounds,
4384 TosaErrorValidator.evIndexUsedTwice,
4385 TosaErrorValidator.evWrongInputType,
4386 TosaErrorValidator.evWrongOutputType,
4387 TosaErrorValidator.evWrongInputList,
4388 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004389 TosaErrorValidator.evWrongRank,
4390 TosaErrorValidator.evRankMismatch,
4391 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004392 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004394 # Data nodes
4395 "const": {
4396 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004397 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004398 "build_fcn": (
4399 build_const,
4400 TosaTensorGen.tgBasic,
4401 TosaTensorValuesGen.tvgDefault,
4402 None,
4403 ),
Luke Hutton65872422023-02-20 10:33:04 +00004404 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004406 "identity": {
4407 "op": Op.IDENTITY,
4408 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004409 "build_fcn": (
4410 build_unary,
4411 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004412 TosaTensorValuesGen.tvgLazyGenDefault,
4413 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004414 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004415 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004416 "data_gen": {
4417 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004419 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004420 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004421 "gather": {
4422 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004423 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004424 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004425 "build_fcn": (
4426 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004427 TosaTensorGen.tgGather,
4428 TosaTensorValuesGen.tvgGather,
4429 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004430 ),
James Ward24dbc422022-10-19 12:20:31 +01004431 "types": (
4432 DType.INT8,
4433 DType.INT16,
4434 DType.INT32,
4435 DType.FP16,
4436 DType.BF16,
4437 DType.FP32,
4438 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004439 "error_if_validators": (
4440 TosaErrorValidator.evWrongInputType,
4441 TosaErrorValidator.evWrongOutputType,
4442 TosaErrorValidator.evWrongInputList,
4443 TosaErrorValidator.evWrongOutputList,
4444 TosaErrorValidator.evWrongRank,
4445 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004446 "data_gen": {
4447 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4448 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004449 },
4450 "scatter": {
4451 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004452 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004453 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004454 "build_fcn": (
4455 build_scatter,
4456 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004457 TosaTensorValuesGen.tvgScatter,
4458 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004459 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004461 "error_if_validators": (
4462 TosaErrorValidator.evWrongInputType,
4463 TosaErrorValidator.evWrongOutputType,
4464 TosaErrorValidator.evWrongInputList,
4465 TosaErrorValidator.evWrongOutputList,
4466 TosaErrorValidator.evWrongRank,
4467 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004468 "data_gen": {
4469 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4470 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004471 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004472 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004473 "resize": {
4474 "op": Op.RESIZE,
4475 "operands": (1, 0),
4476 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004477 "build_fcn": (
4478 build_resize,
4479 TosaTensorGen.tgNHWC,
4480 TosaTensorValuesGen.tvgDefault,
4481 TosaArgGen.agResize,
4482 ),
James Ward24dbc422022-10-19 12:20:31 +01004483 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004484 "invalid_test_validators": (
4485 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004486 ),
4487 "error_if_validators": (
4488 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004489 TosaErrorValidator.evScaleSmallerEqualZero,
4490 TosaErrorValidator.evScaleNLargerMax,
4491 TosaErrorValidator.evScaleDLargerMax,
4492 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004493 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004494 TosaErrorValidator.evBorderSmallerMin,
4495 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004496 TosaErrorValidator.evWrongInputType,
4497 TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongRank,
4499 TosaErrorValidator.evWrongInputList,
4500 TosaErrorValidator.evWrongOutputList,
4501 TosaErrorValidator.evBatchMismatch,
4502 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004503 TosaErrorValidator.evResizeOutputShapeMismatch,
4504 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004505 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004506 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004507 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004508 "cast": {
4509 "op": Op.CAST,
4510 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004511 "build_fcn": (
4512 build_cast,
4513 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004514 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004515 TosaArgGen.agCast,
4516 ),
James Ward8b390432022-08-12 20:48:56 +01004517 "types": (
4518 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004519 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004520 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004521 DType.INT8,
4522 DType.INT16,
4523 DType.INT32,
4524 DType.BOOL,
4525 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004526 "error_if_validators": (
4527 TosaErrorValidator.evWrongInputType,
4528 TosaErrorValidator.evWrongOutputType,
4529 TosaErrorValidator.evWrongInputList,
4530 TosaErrorValidator.evWrongOutputList,
4531 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004532 "data_gen": {
4533 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4534 },
4535 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004536 },
4537 "rescale": {
4538 "op": Op.RESCALE,
4539 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004540 "build_fcn": (
4541 build_rescale,
4542 TosaTensorGen.tgBasic,
4543 TosaTensorValuesGen.tvgDefault,
4544 TosaArgGen.agRescale,
4545 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004546 "types": [
4547 DType.UINT8,
4548 DType.INT8,
4549 DType.INT16,
4550 DType.INT32,
4551 DType.INT48,
4552 DType.UINT16,
4553 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004554 "error_if_validators": (
4555 TosaErrorValidator.evInputZeroPointNotZero,
4556 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004557 TosaErrorValidator.evU16InputZeroPointNotValid,
4558 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004559 TosaErrorValidator.evScaleTrue,
4560 TosaErrorValidator.evScaleNotTrue,
4561 TosaErrorValidator.evWrongInputType,
4562 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 TosaErrorValidator.evWrongInputList,
4564 TosaErrorValidator.evWrongOutputList,
4565 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004566 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004567 # Custom
4568 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004569 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004570 # Two variants of cond_if, one that generates one of two constant tensors (no
4571 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4572 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004573 "cond_if_const": {
4574 "op": Op.COND_IF,
4575 "operands": (0, 2),
4576 "build_fcn": (
4577 build_cond_if_const,
4578 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004579 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004580 TosaArgGen.agCondIf,
4581 ),
4582 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004583 "error_if_validators": (
4584 TosaErrorValidator.evOutputListThenGraphMismatch,
4585 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004586 TosaErrorValidator.evCondIfCondNotMatchingBool,
4587 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004588 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004589 },
4590 "cond_if_binary": {
4591 "op": Op.COND_IF,
4592 "operands": (2, 0),
4593 "build_fcn": (
4594 build_cond_if_binary,
4595 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004596 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004597 TosaArgGen.agCondIf,
4598 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004599 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004600 "error_if_validators": (
4601 TosaErrorValidator.evInputListThenGraphMismatch,
4602 TosaErrorValidator.evInputListElseGraphMismatch,
4603 TosaErrorValidator.evOutputListThenGraphMismatch,
4604 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004605 TosaErrorValidator.evCondIfCondNotMatchingBool,
4606 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004607 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004608 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004609 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004610 "while_loop": {
4611 "op": Op.WHILE_LOOP,
4612 "operands": (0, 1),
4613 "build_fcn": (
4614 build_while_loop,
4615 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004616 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004617 TosaArgGen.agWhileLoop,
4618 ),
4619 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004620 "error_if_validators": (
4621 TosaErrorValidator.evInputListOutputListMismatch,
4622 TosaErrorValidator.evInputListCondGraphMismatch,
4623 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4624 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4625 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004626 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004627 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004628 },
Luke Hutton57287132023-02-06 14:54:18 +00004629 "fft2d": {
4630 "op": Op.FFT2D,
4631 "operands": (2, 0),
4632 "rank": (3, 3),
4633 "build_fcn": (
4634 build_fft2d,
4635 TosaTensorGen.tgFFT2d,
4636 TosaTensorValuesGen.tvgDefault,
4637 TosaArgGen.agFFT2d,
4638 ),
4639 "types": [DType.FP32],
4640 "error_if_validators": (
4641 TosaErrorValidator.evWrongInputType,
4642 TosaErrorValidator.evWrongOutputType,
4643 TosaErrorValidator.evWrongInputList,
4644 TosaErrorValidator.evWrongOutputList,
4645 TosaErrorValidator.evWrongRank,
4646 TosaErrorValidator.evBatchMismatch,
4647 TosaErrorValidator.evKernelNotPowerOfTwo,
4648 TosaErrorValidator.evFFTInputShapeMismatch,
4649 TosaErrorValidator.evFFTOutputShapeMismatch,
4650 ),
4651 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004652 "rfft2d": {
4653 "op": Op.RFFT2D,
4654 "operands": (1, 0),
4655 "rank": (3, 3),
4656 "build_fcn": (
4657 build_rfft2d,
4658 TosaTensorGen.tgRFFT2d,
4659 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004660 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004661 ),
4662 "types": [DType.FP32],
4663 "error_if_validators": (
4664 TosaErrorValidator.evWrongInputType,
4665 TosaErrorValidator.evWrongOutputType,
4666 TosaErrorValidator.evWrongInputList,
4667 TosaErrorValidator.evWrongOutputList,
4668 TosaErrorValidator.evWrongRank,
4669 TosaErrorValidator.evBatchMismatch,
4670 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004671 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004672 ),
4673 },
Won Jeon74342e52024-01-09 00:34:40 +00004674 # Shape
4675 "add_shape": {
4676 "op": Op.ADD_SHAPE,
4677 "operands": (2, 0),
4678 "build_fcn": (
4679 build_shape_op,
4680 TosaTensorGen.tgShape,
4681 TosaTensorValuesGen.tvgAddSub,
4682 TosaArgGen.agNone,
4683 ),
4684 "types": [DType.SHAPE],
4685 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4686 },
4687 "sub_shape": {
4688 "op": Op.SUB_SHAPE,
4689 "operands": (2, 0),
4690 "build_fcn": (
4691 build_shape_op,
4692 TosaTensorGen.tgShape,
4693 TosaTensorValuesGen.tvgAddSub,
4694 TosaArgGen.agNone,
4695 ),
4696 "types": [DType.SHAPE],
4697 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4698 },
4699 "mul_shape": {
4700 "op": Op.MUL_SHAPE,
4701 "operands": (2, 0),
4702 "build_fcn": (
4703 build_shape_op,
4704 TosaTensorGen.tgShape,
4705 TosaTensorValuesGen.tvgMul,
4706 TosaArgGen.agNone,
4707 ),
4708 "types": [DType.SHAPE],
4709 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4710 },
4711 "div_shape": {
4712 "op": Op.DIV_SHAPE,
4713 "operands": (2, 0),
4714 "build_fcn": (
4715 build_shape_op,
4716 TosaTensorGen.tgShape,
4717 TosaTensorValuesGen.tvgIntDiv,
4718 TosaArgGen.agNone,
4719 ),
4720 "types": [DType.SHAPE],
4721 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4722 },
4723 "concat_shape": {
4724 "op": Op.CONCAT_SHAPE,
4725 "operands": (2, 0),
4726 "build_fcn": (
4727 build_concat,
4728 TosaTensorGen.tgConcat,
4729 TosaTensorValuesGen.tvgConcat,
4730 TosaArgGen.agNone,
4731 ),
4732 "types": [DType.SHAPE],
4733 "error_if_validators": (),
4734 },
4735 "const_shape": {
4736 "op": Op.CONST_SHAPE,
4737 "operands": (0, 1),
4738 "build_fcn": (
4739 build_const,
4740 TosaTensorGen.tgBasic,
4741 TosaTensorValuesGen.tvgDefault,
4742 None,
4743 ),
4744 "types": [DType.SHAPE],
4745 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004746 }
4747
Kevin Cheng550ccc52021-03-03 11:21:43 -08004748
Eric Kunzee5e26762020-10-13 16:11:07 -07004749class OutputShaper:
4750 # Methods in this class compute the expected output shape and datatype
4751 # for common classes of operations
4752 def __init__(self):
4753 pass
4754
4755 # These methods return arguments that can be used for
4756 # creating a new output tensor
4757 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting two-input operator.

    Each output dimension comes from ``a``; a size-1 dimension of ``a`` is
    replaced by the matching dimension of ``b`` when no error is requested.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: first input; its ``shape`` and ``dtype`` are read.
        b: second input; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[idx] if (dim == 1 and error_name is None) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The index is drawn on every call, even when it ends up unused.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Pick any listed dtype other than the input dtype.
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004790
4791 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a two-input operator without broadcasting.

    Both inputs must agree in rank, dtype and every dimension (checked with
    assertions); the output copies the dimensions and dtype of ``a``.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        a: first input; its ``shape`` and ``dtype`` are read.
        b: second input; its ``shape`` and ``dtype`` are read.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    assert all(dim_a == dim_b for dim_a, dim_b in zip(a.shape, b.shape))

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004802
4803 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for a single-input operator.

    The output reuses the shape of ``a``. Its dtype is that of ``a`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``, in which case a different
    dtype is drawn from ``rng``.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004821
4822 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select operator.

    The output shape follows ``cond``; where ``cond`` has a size-1 dimension
    and no error is requested, the largest of the three inputs' sizes in
    that dimension is used instead.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        cond: condition input; only its ``shape`` is read.
        a: first value input; its ``shape`` and ``dtype`` are read.
        b: second value input; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and error_name is None)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # The index is drawn on every call, even when it ends up unused.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004855
4856 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting comparison operator.

    The output shape is the broadcast of ``a`` against ``b`` (a size-1
    dimension of ``a`` takes the size from ``b`` when ``b`` has that
    dimension). The dtype is ``DType.BOOL`` unless
    ``ErrorIf.WrongOutputType`` is requested.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The index is drawn on every call, even when it ends up unused.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so every entry is a wrong output dtype.
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004889
4890 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    Normally the reduced axis gets size 1. For the axis-related error cases
    the input shape is left as is, except that ``ShapeOfAxisNotOne`` forces
    a size other than 1 on an axis that already has size 1.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1

    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004918
4919 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for an argmax along ``axis``.

    The output shape is the input shape with ``axis`` removed (unless an
    axis error is requested) and the dtype is ``DType.INT32``. Rank, shape
    and dtype faults are injected according to ``error_name``.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dimension or append an extra one.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004952
4953 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Layouts, as noted in the original comments: IFM is NHWC, filter is OHWI
    and OFM is NHWC.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input feature map; its ``shape`` and ``dtype`` are read.
        filter: weights; its ``shape`` is read.
        accum_dtype: dtype used for the output when no dtype fault is set.
        strides: per-axis strides (height, width).
        padding: four values; indices 0-1 pad height, 2-3 pad width.
        dilations: per-axis dilations (height, width).
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    kernel_h_span = (filter.shape[1] - 1) * dilations[0]
    kernel_w_span = (filter.shape[2] - 1) * dilations[1]

    h = (ifm.shape[1] - 1 + padding[0] + padding[1] - kernel_h_span) // strides[
        0
    ] + 1
    w = (ifm.shape[2] - 1 + padding[2] + padding[3] - kernel_w_span) // strides[
        1
    ] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height, 2 -> change width, 3 -> change both
        change = rng.choice(choices)
        # Offsets are multiples of the stride so that the non-integer
        # error case is not hit instead.
        if change in (1, 3):
            h = h + (rng.choice(choices) * strides[0])
        if change in (2, 3):
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a wrong input type, fall back to a potentially correct output
    # dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005004
5005 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    Layouts, as noted in the original comments: IFM is NDHWC, filter is
    ODHWI and OFM is NDHWC.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input feature map; its ``shape`` and ``dtype`` are read.
        filter: weights; its ``shape`` is read.
        accum_dtype: dtype used for the output when no dtype fault is set.
        strides: per-axis strides (depth, height, width).
        padding: six values; index pairs pad depth, height and width.
        dilations: per-axis dilations (depth, height, width).
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    kernel_d_span = (filter.shape[1] - 1) * dilations[0]
    kernel_h_span = (filter.shape[2] - 1) * dilations[1]
    kernel_w_span = (filter.shape[3] - 1) * dilations[2]

    d = (ifm.shape[1] - 1 + padding[0] + padding[1] - kernel_d_span) // strides[
        0
    ] + 1
    h = (ifm.shape[2] - 1 + padding[2] + padding[3] - kernel_h_span) // strides[
        1
    ] + 1
    w = (ifm.shape[3] - 1 + padding[4] + padding[5] - kernel_w_span) // strides[
        2
    ] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # 1 -> depth, 2 -> height, 3 -> width, 4 -> all three
        change = rng.choice(choices)
        # Offsets are multiples of the stride so that the non-integer
        # error case is not hit instead.
        if change in (1, 4):
            d = d + (rng.choice(choices) * strides[0])
        if change in (2, 4):
            h = h + (rng.choice(choices) * strides[1])
        if change in (3, 4):
            w = w + (rng.choice(choices) * strides[2])

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # With a wrong input type, fall back to a potentially correct output
    # dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5066
5067 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a depthwise 2D convolution.

    Layouts, as noted in the original comments: IFM is NHWC, filter is HWCM
    and OFM is NHW with C*M channels.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input feature map; its ``shape`` and ``dtype`` are read.
        filter: weights; its ``shape`` is read.
        accum_dtype: dtype used for the output when no dtype fault is set.
        strides: per-axis strides (height, width).
        padding: four values; indices 0-1 pad height, 2-3 pad width.
        dilations: per-axis dilations (height, width).
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    kernel_h_span = (filter.shape[0] - 1) * dilations[0]
    kernel_w_span = (filter.shape[1] - 1) * dilations[1]

    h = (ifm.shape[1] - 1 + padding[0] + padding[1] - kernel_h_span) // strides[
        0
    ] + 1
    w = (ifm.shape[2] - 1 + padding[2] + padding[3] - kernel_w_span) // strides[
        1
    ] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height, 2 -> change width, 3 -> change both
        change = rng.choice(choices)
        # Offsets are multiples of the stride so that the non-integer
        # error case is not hit instead.
        if change in (1, 3):
            h = h + (rng.choice(choices) * strides[0])
        if change in (2, 3):
            w = w + (rng.choice(choices) * strides[1])

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], h, w, out_channels]

    # With a wrong input type, fall back to a potentially correct output
    # dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005117
5118 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (input is NHWC).

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input feature map; its ``shape`` and ``dtype`` are read.
        kernel: kernel size (height, width).
        stride: strides (height, width).
        pad: four values; indices 0-1 pad height, 2-3 pad width.
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride or padding: the test is invalid anyway, so just use
        # unit dimensions.
        h = 1
        w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height, 2 -> change width, 3 -> change both
        change = rng.choice(choices)
        # Offsets are multiples of the stride so that the non-integer
        # error case is not hit instead.
        if change in (1, 3):
            h = h + (rng.choice(choices) * stride[0])
        if change in (2, 3):
            w = w + (rng.choice(choices) * stride[1])

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005155
5156 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a fully connected operator.

    Layouts, as noted in the original comments: input is (N, IC), filter is
    (OC, IC) and output is (N, OC). ``rng`` and ``error_name`` are accepted
    but not used here; the original comment states the dtype is validated
    (and invalidated for ErrorIf) in arg_gen.

    Returns:
        The result of ``ser.addOutput`` for shape (N, OC) and ``accum_dtype``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]

    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005168
5169 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a batched matrix multiply.

    Layouts, as noted in the original comments: ``a`` is (N, H, C), ``b``
    is (N, C, W) and the output is (N, H, W).

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        a: first input; its ``shape`` and ``dtype`` are read.
        b: second input; its ``shape`` is read.
        accum_dtype: output dtype for the non-error case (the original
            comment states it is validated in arg_gen).
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        ValueError: if ``ErrorIf.WrongOutputType`` is requested for an input
            dtype that has no list of incorrect output dtypes.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Previously this fell through with incorrect_types unbound and
            # failed with an UnboundLocalError at the rng.choice call below.
            raise ValueError(
                f"matmulOp: no incorrect output dtypes defined for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005216
5217 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for a concatenation along ``axis``.

    The output starts as a copy of the first input's shape. Unless a rank or
    axis error is requested, the ``axis`` sizes of the remaining inputs are
    added on.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        axis: dimension along which the inputs are joined.
        inputs: sequence of inputs; each ``shape`` is read, and the first
            one's ``dtype`` sets the output dtype.
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    first = inputs[0]
    others = inputs[1:]

    # Fall back to the first input's shape when the sum cannot be formed.
    output_shape = first.shape.copy()

    # Mismatched ranks or an invalid axis make the sum impossible to compute.
    shape_is_computable = (
        error_name != ErrorIf.ConcatInputRankMismatch
        and error_name not in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
    )
    if shape_is_computable:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - set([first.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005252
5253 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for a pad operator.

    Each output dimension is the input dimension plus the two padding
    amounts ``padding[i][0]`` and ``padding[i][1]`` for that dimension.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        a: input; its ``shape`` and ``dtype`` are read.
        padding: per-dimension pairs of padding amounts.
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = a.shape.copy()

    for idx, dim in enumerate(output_shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005283
5284 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a dim operator.

    The output always has shape ``[1]``. Its dtype is ``DType.SHAPE`` unless
    ``ErrorIf.WrongOutputType`` is requested. ``a`` and ``axis`` are
    accepted but not used here.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_dtype = DType.SHAPE
    if error_name == ErrorIf.WrongOutputType:
        # SHAPE is not in the list, so every candidate is a wrong dtype.
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates)))

    return ser.addOutput([1], out_dtype)
5304
5305 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for a reshape to ``shape``.

    The output uses a copy of ``shape`` and the dtype of ``a``. For
    ``TensorSizeInputOutputMismatch`` every dimension is enlarged by a
    random amount.

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005329
5330 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for a slice operator.

    The output shape is a copy of ``size`` and the dtype is that of
    ``input``, with faults injected according to ``error_name``. ``start``
    is accepted but not used here.

    Returns:
        The result of ``ser.addOutput``.
    """
    # The dtype is settled first, before any shape-related random draws.
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([input.dtype])))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            # Extents of 2 or less only grow; larger ones may also shrink.
            if extent <= 2:
                deltas = [1, 2]
            else:
                deltas = [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005363
5364 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for a tile operator.

    Each output dimension is the input dimension multiplied by the matching
    entry of ``multiples``, which must have one entry per input dimension
    (checked with an assertion).

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx in range(len(output_shape)):
        output_shape[idx] = a.shape[idx] * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005392
5393 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for a transpose operator.

    Output dimension ``i`` is input dimension ``perms[i]``. For the two
    permutation-index error cases the input shape is kept unpermuted.

    Args:
        ser: object whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: input; its ``shape`` and ``dtype`` are read.
        perms: permutation, one entry per input dimension (asserted).
        error_name: optional ``ErrorIf`` value selecting a deliberate fault.

    Returns:
        The result of ``ser.addOutput``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    perms_usable = error_name not in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if perms_usable:
        for idx in range(len(output_shape)):
            output_shape[idx] = a.shape[perms[idx]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(output_shape)):
            output_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005425
5426 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor for a gather of ``values`` by ``indices``.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.  The rank-3
    check on ``values`` is skipped for ``ErrorIf.WrongRank``; ``indices``
    must always be rank 2 and share its first dimension with ``values``.

    For ``ErrorIf.WrongOutputType`` the dtype is picked with ``rng.choice``
    from the listed dtypes excluding ``values.dtype``; otherwise
    ``values.dtype`` is used.

    Returns the value returned by ``ser.addOutput``.
    """
    values_rank_is_checked = error_name != ErrorIf.WrongRank
    if values_rank_is_checked:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    gathered_shape = [
        values.shape[0],
        indices.shape[1],
        values.shape[2],
    ]

    result_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched_dtypes = list(set(candidate_dtypes) - {values.dtype})
        result_dtype = rng.choice(mismatched_dtypes)

    return ser.addOutput(gathered_shape, result_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005451
5452 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor for a scatter into ``values_in``.

    The output uses ``values_in.shape`` directly (the same object, not a
    copy).  The rank-3 check on ``values_in`` is skipped for
    ``ErrorIf.WrongRank``; the remaining shape checks always run.

    For ``ErrorIf.WrongOutputType`` the dtype is picked with ``rng.choice``
    from the listed dtypes excluding ``values_in.dtype``; otherwise
    ``values_in.dtype`` is used.

    Returns the value returned by ``ser.addOutput``.
    """
    values_rank_is_checked = error_name != ErrorIf.WrongRank
    if values_rank_is_checked:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    # Dimension labels below follow the original N / W / C annotations
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    scattered_shape = values_in.shape

    result_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched_dtypes = list(set(candidate_dtypes) - {values_in.dtype})
        result_dtype = rng.choice(mismatched_dtypes)

    return ser.addOutput(scattered_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005480
5481 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor for a table lookup on ``input``.

    The output has the same shape as the input.  Its dtype is ``INT32``
    when the input is ``INT16`` and ``INT8`` for any other input dtype.
    Unless ``error_name`` is ``ErrorIf.WrongInputType`` the input dtype
    must be ``INT16`` or ``INT8``.

    For ``ErrorIf.WrongOutputType`` the correct dtype is removed from the
    listed dtypes and the output dtype is picked from the rest with
    ``rng.choice``.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        expected_dtype = DType.INT32
    else:
        expected_dtype = DType.INT8

    result_dtype = expected_dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Keep list order; only the expected dtype is dropped
        mismatched_dtypes = [d for d in candidate_dtypes if d != expected_dtype]
        result_dtype = rng.choice(mismatched_dtypes)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005500
5501 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor for a resize of ``input``.

    ``scale`` is read as ``[scale_y_n, scale_y_d, scale_x_n, scale_x_d]``.
    Output height and width are computed from ``input.shape[1]`` and
    ``input.shape[2]`` as
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``.

    ``mode`` and ``input_dtype`` are accepted but not used here.

    Effects of ``error_name``:

    * ``ScaleSmallerEqualZero``: all four scale terms are raised to at
      least 1 before the size calculation.
    * Any non-``None`` value: height and width are raised to at least 1
      and, unless it is ``MaxDimExceeded``, capped at
      ``gtu.MAX_RESIZE_DIMENSION - 1``.
    * ``ResizeOutputShapeMismatch``: height and/or width is shifted by one
      multiple of its scale denominator.
    * ``WrongRank``: the last output dimension is ``input.shape[0]``.
    * ``BatchMismatch`` / ``ChannelMismatch``: ``rng.integers(1, 10)`` is
      added to the first / last output dimension.

    Returns the value returned by ``serializer.addOutput`` with
    ``output_dtype``.
    """

    def _shift_by_step(size, step):
        # Move down when moving up would reach the dimension limit
        if size + step >= gtu.MAX_RESIZE_DIMENSION:
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size
        return size + step

    # Calculate OH, OW
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        y_num = max(y_num, 1)
        y_den = max(y_den, 1)
        x_num = max(x_num, 1)
        x_den = max(x_den, 1)

    out_height = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_width = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        out_height = max(out_height, 1)
        out_width = max(out_width, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            largest_allowed = gtu.MAX_RESIZE_DIMENSION - 1
            out_height = min(out_height, largest_allowed)
            out_width = min(out_width, largest_allowed)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> change height, 2 -> change width, 3 -> change both
        which = rng.choice([1, 2, 3])
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if which in (1, 3):
            out_height = _shift_by_step(out_height, y_den)
        if which in (2, 3):
            out_width = _shift_by_step(out_width, x_den)

    # Only the first and last dimensions differ between the cases
    first_dim = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        last_dim = input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        first_dim = input.shape[0] + rng.integers(1, 10)
        last_dim = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        last_dim = input.shape[3] + rng.integers(1, 10)
    else:
        last_dim = input.shape[3]

    output_dims = [first_dim, out_height, out_width, last_dim]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005579
5580 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not used.

    Returns the value returned by ``ser.addOutput``.
    """
    converted_shape = val.shape
    return ser.addOutput(converted_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005583
5584 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor for a transpose convolution.

    ``output_shape`` is supplied by the caller and used as the output
    shape.  For ``ErrorIf.ConvOutputShapeMismatch`` dimension 1 and/or
    dimension 2 is increased by 1, 2 or 3; this edits ``output_shape`` in
    place, so the caller's object changes as well.

    The output dtype is ``accum_dtype``, or ``DType.INT32`` for
    ``ErrorIf.WrongInputType``.  For ``ErrorIf.WrongOutputType`` it is
    picked with ``rng.choice`` from ``gtu.usableDTypes`` excluding FP16 and
    FP32 when ``ifm.dtype`` is FP16, and excluding the expected dtype
    otherwise.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        increments = [1, 2, 3]
        # 1 -> change dim 1, 2 -> change dim 2, 3 -> change both
        which = rng.choice(increments)
        if which in (1, 3):
            output_shape[1] += rng.choice(increments)
        if which in (2, 3):
            output_shape[2] += rng.choice(increments)

    # Pick some potentially correct output dtype if input type is incorrect
    result_dtype = (
        DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    )

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [result_dtype]
        mismatched_dtypes = list(gtu.usableDTypes(excludes=excluded))
        result_dtype = rng.choice(mismatched_dtypes)

    return ser.addOutput(output_shape, result_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005609
5610 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for a 2D FFT of ``ifm1`` / ``ifm2``.

    Both inputs must share a dtype.  Their shapes must match unless
    ``error_name`` is ``ErrorIf.FFTInputShapeMismatch``, and ``ifm1`` must
    be rank 3 unless it is ``ErrorIf.WrongRank``.  The outputs use a copy
    of ``ifm1.shape`` and the input dtype, except:

    * ``WrongOutputType``: dtype is picked with ``rng.choice`` from
      ``gtu.usableDTypes`` excluding ``DType.FP32``.
    * ``BatchMismatch``: ``rng.integers(1, 10)`` is added to dimension 0.
    * ``FFTOutputShapeMismatch``: ``rng.integers(1, 10)`` is added to
      dimension 1 or 2, chosen with ``rng.choice``.

    Returns a list of two values from ``serializer.addOutput``, both
    created with the same shape and dtype.
    """
    assert ifm1.dtype == ifm2.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    source_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(source_shape) == 3

    result_shape = source_shape.copy()
    result_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        mismatched_dtypes = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(mismatched_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis_to_grow = rng.choice([1, 2])
        result_shape[axis_to_grow] += rng.integers(1, 10)

    # Two outputs are registered, one after the other
    return [serializer.addOutput(result_shape, result_dtype) for _ in range(2)]
5640
5641 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for a real 2D FFT of ``value``.

    ``value`` must be rank 3 unless ``error_name`` is
    ``ErrorIf.WrongRank``.  The output shape keeps every input dimension
    except the last, which becomes ``last // 2 + 1``.  The outputs use
    ``value.dtype``, except:

    * ``WrongOutputType``: dtype is picked with ``rng.choice`` from
      ``gtu.usableDTypes`` excluding ``DType.FP32``.
    * ``BatchMismatch``: ``rng.integers(1, 10)`` is added to dimension 0.
    * ``FFTOutputShapeMismatch``: ``rng.integers(1, 10)`` is added to
      dimension 1 or 2, chosen with ``rng.choice``.

    Returns a list of two values from ``serializer.addOutput``, both
    created with the same shape and dtype.
    """
    source_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(source_shape) == 3

    halved_last_dim = source_shape[-1] // 2 + 1
    result_shape = list(source_shape[:-1]) + [halved_last_dim]

    result_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        mismatched_dtypes = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(mismatched_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis_to_grow = rng.choice([1, 2])
        result_shape[axis_to_grow] += rng.integers(1, 10)

    # Two outputs are registered, one after the other
    return [serializer.addOutput(result_shape, result_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005665
5666 @staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for an element-wise op on shape tensors.

    ``a`` and ``b`` must share a dtype, and must have the same rank unless
    ``error_name`` is ``ErrorIf.RankMismatch``.  The output shape is a copy
    of ``a.shape``; for ``ErrorIf.DimensionMismatch`` one randomly chosen
    dimension is increased by 1.

    The output dtype is ``DType.SHAPE``, or for ``ErrorIf.WrongOutputType``
    a ``rng.choice`` from ``gtu.get_wrong_output_type(a.dtype)``.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    result_shape = list(a.shape)

    # The index is drawn on every call, whether or not it is used below
    axis_to_grow = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        result_shape[axis_to_grow] += 1

    result_dtype = DType.SHAPE
    if error_name == ErrorIf.WrongOutputType:
        mismatched_dtypes = gtu.get_wrong_output_type(a.dtype)
        result_dtype = rng.choice(mismatched_dtypes)

    return ser.addOutput(result_shape, result_dtype)