blob: 49d9f1bf53e8e0e422770a037c8c7246e2db756c [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of the generated C++ meta data files
# (see serialize); the copyright year comes from the local date at import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further limits whose names indicate they come from the specification's
    # 8K level (scale, kernel and stride maxima).
    # NOTE(review): confirm these values against the TOSA specification.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # NOTE(review): presumably the number of dot product test sets and a
    # minimum count used when sizing those tests - confirm at the call sites.
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
79 for dtype in (DType.FP32, DType.FP16, DType.BF16):
80 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000198 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 return np.int64(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
201 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
202
203 if dtype == DType.FP16:
204 return np.float16(f_tensor)
205 else:
206 f32_tensor = np.float32(f_tensor)
207 if dtype == DType.BF16:
208 # Floor the last 16 bits of each f32 value
209 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
210 else:
211 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100213 # All other integer types
214 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
Kevin Cheng989cb052021-04-28 16:29:44 -0700216 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700217 placeholders = []
218
Kevin Cheng989cb052021-04-28 16:29:44 -0700219 assert len(shape_list) == len(dtype_list)
220
Jeremy Johnson1271c442023-09-05 11:39:26 +0100221 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700222 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 if not self.args.lazy_data_gen:
224 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700225 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700226
227 return placeholders
228
Kevin Cheng989cb052021-04-28 16:29:44 -0700229 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 consts = []
231
Kevin Cheng989cb052021-04-28 16:29:44 -0700232 assert len(shape_list) == len(dtype_list)
233
Jeremy Johnson1271c442023-09-05 11:39:26 +0100234 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700235 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100236 if not self.args.lazy_data_gen:
237 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700238 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700239
240 return consts
241
242 def makeShape(self, rank):
243 if self.targetted_shape:
244 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800245 return np.int32(
246 self.rng.integers(
247 low=self.args.tensor_shape_range[0],
248 high=self.args.tensor_shape_range[1],
249 size=rank,
250 )
251 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700252
253 def setTargetShape(self, shape):
254 self.targetted_shape = shape
255
256 def randInt(self, low=0, high=256):
257 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
258
259 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 low, high = self.getDTypeRange(dtype)
261
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100264 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100265 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100266 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100267 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
268 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700269 elif dtype == DType.BOOL:
270 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000271 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700272 # Special size
273 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
275 return np.int32(self.rng.integers(low, high, size=1))[0]
276
277 def shapeStr(self, shape):
278
279 sStr = []
280 # Convert to strings
281 for i in shape:
282 sStr.append(str(i))
283
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeStr(self, dtype):
287 if isinstance(dtype, list) or isinstance(dtype, tuple):
288 assert len(dtype) >= 2
289 strs = [self.typeStr(t) for t in dtype]
290 # Limit types to the first 2 as the 3rd is the accumulator
291 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100293 if dtype in gtu.DTYPE_ATTRIBUTES:
294 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700295 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100296 raise Exception(
297 "Unknown dtype, cannot convert to string: {}".format(dtype)
298 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100300 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100301 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100302 if dtype in gtu.DTYPE_ATTRIBUTES:
303 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100305 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700306
Luke Hutton57287132023-02-06 14:54:18 +0000307 def constrictBatchSize(self, shape):
308 # Limit the batch size unless an explicit target shape set
309 if self.args.max_batch_size and not self.args.target_shapes:
310 shape[0] = min(shape[0], self.args.max_batch_size)
311 return shape
312
James Ward30124a82023-02-02 14:56:33 +0000313 def makeDimension(self):
314 return self.randInt(
315 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
316 )
317
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100318 def tensorComplianceMetaData(
319 self, op, inputType, argsDict, outputTensor, errorName
320 ):
Jeremy Johnson708da822023-11-15 16:25:45 +0000321 # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
322 UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100323 if (
324 errorName
325 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000326 or (
327 not gtu.dtypeIsSupportedByCompliance(inputType)
328 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
329 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100330 ):
331 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100332 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100333
Jeremy Johnson1271c442023-09-05 11:39:26 +0100334 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100335 compliance_tens = {
336 "mode": None,
337 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
338 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
339 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100340 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
341 mode = gtu.ComplianceMode.DOT_PRODUCT
342 compliance_tens["dot_product_info"] = {
343 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100344 "ks": int(argsDict["ksb"])
345 if "ksb" in argsDict
346 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100347 }
348 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
349 mode = gtu.ComplianceMode.FP_SPECIAL
350 elif "compliance" in op and "ulp" in op["compliance"]:
351 mode = gtu.ComplianceMode.ULP
352 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
353 elif op["op"] == Op.REDUCE_PRODUCT:
354 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000355 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000356 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000357 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000358 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
359 compliance_tens["abs_error_info"] = {
360 "lower_bound": op["compliance"]["abs_error_lower_bound"]
361 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100362 else:
363 mode = gtu.ComplianceMode.EXACT
364 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
365
366 return compliance_tens
367
368 # Build Op functions
369 # Create the output tensor (calling OutputShaper as needed)
370 # Do final tweaks to attributes (if necessary for errorIf)
371 # Add Op into graph
372 # Return resulting tensor information or BuildInfo
373
374 class BuildInfo:
375 """Enhanced build information containing result tensor and associated compliance dict."""
376
377 def __init__(self, resultTensor, complianceDict):
378 self.resultTensor = resultTensor
379 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700380
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000381 def build_unary(
382 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
383 ):
384 assert len(inputs) == 1
385 a = inputs[0]
386 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100387
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000388 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100389
390 # Ensure new output type has correct qinfo
391 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000392 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000393 qinfo = [
394 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000395 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000396 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100397
398 # Invalidate Input/Output list for error if checks.
399 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000400 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100401 pCount, cCount = op["operands"]
402 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000403 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
404 self, error_name, input_list, output_list
405 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100406
Les Bell729b0352021-11-24 10:28:21 +0000407 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100408 self.ser,
409 validator_fcns,
410 error_name,
411 op=op,
412 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000413 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000414 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000415 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100416 input_list=input_list,
417 output_list=output_list,
418 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000419 ):
420 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100421
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000422 attr = None
423 if op["op"] == Op.NEGATE:
424 attr = ts.TosaSerializerAttribute()
425 attr.NegateAttribute(qinfo[0], qinfo[1])
426
427 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000428
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000429 compliance = self.tensorComplianceMetaData(
430 op, a.dtype, args_dict, result_tensor, error_name
431 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000432 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700433
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000434 def build_binary_broadcast(
435 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
436 ):
437 assert len(inputs) == 2
438 a, b = inputs
439 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000440 self.ser, self.rng, a, b, error_name
441 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100442
443 # Invalidate Input/Output list for error if checks.
444 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000445 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100446 pCount, cCount = op["operands"]
447 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000448 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
449 self, error_name, input_list, output_list
450 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100451
Les Bell729b0352021-11-24 10:28:21 +0000452 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100453 self.ser,
454 validator_fcns,
455 error_name,
456 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 input1=a,
458 input2=b,
459 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000460 output_dtype=result_tensor.dtype,
461 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100462 input_list=input_list,
463 output_list=output_list,
464 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000465 ):
466 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100467
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000468 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000469
Jeremy Johnson9a758382023-11-07 16:27:35 +0000470 compliance = self.tensorComplianceMetaData(
471 op, a.dtype, args_dict, result_tensor, error_name
472 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000473
474 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700475
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100476 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700477 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000478 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 return result_tens
480
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000481 def build_arithmetic_right_shift(
482 self, op, a, b, round, validator_fcns=None, error_name=None
483 ):
484 result_tens = OutputShaper.binaryBroadcastOp(
485 self.ser, self.rng, a, b, error_name
486 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100487
488 # Invalidate Input/Output list for error if checks.
489 input_list = [a.name, b.name]
490 output_list = [result_tens.name]
491 pCount, cCount = op["operands"]
492 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000493 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
494 self, error_name, input_list, output_list
495 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100496
Les Bell729b0352021-11-24 10:28:21 +0000497 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100498 self.ser,
499 validator_fcns,
500 error_name,
501 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000502 input1=a,
503 input2=b,
504 input_dtype=a.dtype,
505 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000506 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100507 input_list=input_list,
508 output_list=output_list,
509 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000510 ):
511 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800512
513 attr = ts.TosaSerializerAttribute()
514 attr.ArithmeticRightShiftAttribute(round)
515
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000516 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800517 return result_tens
518
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100519 def build_mul(
520 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
521 ):
522 assert len(inputs) == 2
523 a, b = inputs
524 shift = args_dict["shift"]
525
526 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000527 self.ser, self.rng, a, b, error_name
528 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700529
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100530 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100531 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100532 result_tensor.setDtype(DType.INT32)
533
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100534 if error_name == ErrorIf.WrongOutputType:
535 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
536 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100537 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100538
539 # Invalidate Input/Output list for error if checks.
540 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100541 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100542 pCount, cCount = op["operands"]
543 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000544 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
545 self, error_name, input_list, output_list
546 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100547
Les Bell729b0352021-11-24 10:28:21 +0000548 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100549 self.ser,
550 validator_fcns,
551 error_name,
552 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000553 input1=a,
554 input2=b,
555 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100556 output_dtype=result_tensor.dtype,
557 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100558 input_list=input_list,
559 output_list=output_list,
560 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000561 ):
562 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700563
Kevin Chengaee1fac2020-11-11 13:54:06 -0800564 attr = ts.TosaSerializerAttribute()
565 attr.MulAttribute(shift)
566
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000567 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100568
569 compliance = self.tensorComplianceMetaData(
570 op, a.dtype, args_dict, result_tensor, error_name
571 )
572
573 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700574
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100575 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
576 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700577
Kevin Chengfe392ce2021-10-18 21:51:55 +0000578 attr = ts.TosaSerializerAttribute()
579 attr.TableAttribute(table)
580
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581 # Invalidate Input/Output list for error if checks.
582 input_list = [a.name]
583 output_list = [result_tens.name]
584 pCount, cCount = op["operands"]
585 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000586 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
587 self, error_name, input_list, output_list
588 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100589
Les Bell729b0352021-11-24 10:28:21 +0000590 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100591 self.ser,
592 validator_fcns,
593 error_name,
594 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000595 input_shape=a.shape,
596 input_dtype=a.dtype,
597 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000598 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100599 input_list=input_list,
600 output_list=output_list,
601 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000602 ):
603 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100604
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000605 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700606
607 return result_tens
608
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000609 def build_select(
610 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
611 ):
612 assert len(inputs) == 3
613 cond, a, b = inputs
614
615 result_tensor = OutputShaper.selectOp(
616 self.ser, self.rng, cond, a, b, error_name
617 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100618
619 # Invalidate Input/Output list for error if checks.
620 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000621 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100622 pCount, cCount = op["operands"]
623 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000624 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
625 self, error_name, input_list, output_list
626 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100627
Les Bell729b0352021-11-24 10:28:21 +0000628 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100629 self.ser,
630 validator_fcns,
631 error_name,
632 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000633 input1=cond,
634 input2=a,
635 input3=b,
636 input_shape=a.shape,
637 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000638 output_dtype=result_tensor.dtype,
639 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100640 input_list=input_list,
641 output_list=output_list,
642 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000643 ):
644 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100645
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000646 self.ser.addOperator(
647 op["op"],
648 input_list,
649 output_list,
650 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000651 compliance = self.tensorComplianceMetaData(
652 op, a.dtype, args_dict, result_tensor, error_name
653 )
654
655 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a binary comparison operator on two input tensors.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly two tensors (lhs, rhs).
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    n_params, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tens.name]
    )

    checks = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_params + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700704
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Serialize an ARGMAX operator over args_dict["axis"] of a single tensor.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly one input tensor.
        args_dict: test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]

    out_tens = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    n_params, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_params + n_consts,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(
        out_tens,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700748
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator on a single input tensor.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly one input tensor.
        args_dict: test arguments; must contain "stride", "pad" and "kernel",
            and may contain "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional pair of zero points; replaced for WrongInputType
            tests on non-8-bit inputs, and treated as [0, 0] when None.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype, so fall back to DType.UNKNOWN
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    needs_new_qinfo = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if needs_new_qinfo:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tens.dtype),
        ]

    n_params, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tens.shape,
        output_dtype=out_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_params + n_consts,
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
James Ward8b390432022-08-12 20:48:56 +0100822
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV2D operator from (input, weight, bias) tensors.

    Args:
        op: operator description; op["op"] is the opcode and the sum of
            op["operands"] is handed to the validators.
        inputs: exactly three tensors (ifm, weight, bias).
        args_dict: test arguments; must contain "acc_type", "stride",
            "pad" (4 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional pair of zero points passed to ConvAttribute. It is
            replaced for WrongInputType tests on non-8-bit inputs. If it is
            still None after validation, both entries default to 0; the
            original code raised TypeError on qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # here (after validation, so validators still see the caller's value)
    # instead of failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700904
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator from (input, weight, bias) tensors.

    Args:
        op: operator description; op["op"] is the opcode and the sum of
            op["operands"] is handed to the validators.
        inputs: exactly three tensors (ifm, weight, bias).
        args_dict: test arguments; must contain "acc_type", "stride",
            "pad" (6 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional pair of zero points passed to ConvAttribute. It is
            replaced for WrongInputType tests on non-8-bit inputs. If it is
            still None after validation, both entries default to 0; the
            original code raised TypeError on qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise the output
        tensor itself (this builder does not produce compliance info).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # here (after validation, so validators still see the caller's value)
    # instead of failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
981
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator.

    The tensors and arguments arrive as individual parameters. The name
    ``filter`` is kept because it is part of the public signature.

    Args:
        op: operator description; op["op"] is the opcode and the sum of
            op["operands"] is handed to the validators.
        ifm, filter, bias: the input, weight and bias tensors.
        accum_dtype: accumulator type passed to the output shaper.
        stride: stride values for the operator.
        out_pad: 4 output padding values.
        output_shape: requested output shape, used by the output shaper and
            stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional pair of zero points passed to
            TransposeConvAttribute. It is replaced for WrongInputType tests
            on non-8-bit inputs. If it is still None after validation, both
            entries default to 0; the original code raised TypeError on
            qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise the output
        tensor itself.
    """
    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # here (after validation, so validators still see the caller's value)
    # instead of failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
1049
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator from (input, weight, bias).

    Args:
        op: operator description; op["op"] is the opcode and the sum of
            op["operands"] is handed to the validators.
        inputs: exactly three tensors (ifm, weight, bias).
        args_dict: test arguments; must contain "acc_type", "stride", "pad"
            and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional pair of zero points passed to ConvAttribute. It is
            replaced for WrongInputType tests on non-8-bit inputs. If it is
            still None after validation, both entries default to 0; the
            original code raised TypeError on qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise the output
        tensor itself (this builder does not produce compliance info).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # here (after validation, so validators still see the caller's value)
    # instead of failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
1125
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator from (input, weight, bias).

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly three tensors (ifm, weight, bias).
        args_dict: test arguments; must contain "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional two-entry list passed to FullyConnectedAttribute.
            If it is None after validation, both entries default to 0; the
            original code raised TypeError on qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zeros here (after
    # validation, so validators still see the caller's value) instead of
    # failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001181
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL operator on two input tensors.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly two tensors (a, b).
        args_dict: test arguments; must contain "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: optional two-entry list passed to MatMulAttribute. If it is
            None after validation, both entries default to 0; the original
            code raised TypeError on qinfo[0] in that case.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zeros here (after
    # validation, so validators still see the caller's value) instead of
    # failing when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001231
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize a reduction operator over args_dict["axis"] of one tensor.

    For a non-error REDUCE_PRODUCT test, args_dict is updated in place with
    "n", the length of the reduced axis, before the compliance metadata is
    built from it.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly one input tensor.
        args_dict: test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tens = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    n_params, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tens.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tens.shape,
        input_dtype=src.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_params + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[axis]

    return TosaTestGen.BuildInfo(
        out_tens,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001280
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes min_val and the larger one max_val. For ErrorIf.MaxSmallerMin
    the pair is redrawn until the values differ, and the roles are then
    swapped so that max_val < min_val.

    Args:
        op: operator description; op["op"] is the opcode and op["operands"]
            unpacks into two counts whose sum is handed to the validators.
        inputs: exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        None when ERROR_IF validation rejects the test, otherwise a
        TosaTestGen.BuildInfo holding the output tensor and compliance info.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tens = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    bounds = [self.getRandNumberDType(src.dtype) for _ in range(2)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(src.dtype) for _ in range(2)]
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    n_params, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tens.shape,
        input_dtype=src.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_params + n_consts,
    ):
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes pass the bounds in the first pair of values.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes pass the bounds in the second pair of values.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001346
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator on tensor ``a``.

    The attribute value is a random FP32 number. validator_fcns is accepted
    but not used, so no ERROR_IF validation happens here.

    Returns:
        The output tensor created by OutputShaper.unaryOp.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    leaky_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tens.name], leaky_attr)
    return out_tens
1355
1356 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator with ``a`` as its sole input.

    NOTE(review): only ``a`` is wired up here, although PRELU is noted as
    needing an additional input. ``validator_fcns`` is unused.

    Returns:
        The result tensor produced by ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1362
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input, attribute-free activation operator.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` gives the (placeholder, const) operand counts.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its
        compliance metadata, or None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001403
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator joining all of ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it in an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and has no attribute.

    Args:
        op: Operator description dict; ``op["op"]`` is CONCAT or CONCAT_SHAPE
            and ``op["operands"]`` holds the (placeholder, const) counts.
        inputs: Tensors to concatenate.
        args_dict: Test arguments (``"axis"`` for CONCAT), also forwarded to
            the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    # CONCAT_SHAPE has no axis argument: shape values are joined along axis 0.
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance() rather than comparing type() directly (PEP 8).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001462
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator on the single tensor in ``inputs``.

    Reads ``"pad"``, ``"pad_const_int"`` and ``"pad_const_fp"`` from
    ``args_dict``. The padding is flattened before it is stored in the
    PadAttribute. The attribute is created before ERROR_IF validation.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, in_tensor, padding, error_name
    )

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=in_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001520
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator on the single tensor in ``inputs``.

    The axis comes from ``args_dict["axis"]`` and is serialized in an
    AxisAttribute. No compliance metadata is attached to the result.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and a compliance of
        None, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, in_tensor, axis, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    dim_attr = ts.TosaSerializerAttribute()
    dim_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, dim_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001566
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator from ``inputs = [tensor, shape_tensor]``.

    The shape tensor is wired in as the second operand. The output shape
    is computed from ``args_dict["new_shape"]``. No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_tensor = inputs[0]
    shape_tensor = inputs[1]
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, in_tensor, new_shape, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name, shape_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001610
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator on the single tensor in ``inputs``.

    The axis comes from ``args_dict["axis"]`` and is serialized in an
    AxisAttribute. No compliance metadata is attached to the result.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and a compliance of
        None, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    reverse_attr = ts.TosaSerializerAttribute()
    reverse_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, reverse_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001650
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on tensor ``a`` using permutation ``perms``.

    The TransposeAttribute is created before ERROR_IF validation runs.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [a.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)
    return out_tensor
1686
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator on the single tensor in ``inputs``.

    ``args_dict["start"]`` and ``args_dict["size"]`` feed the output shaper,
    the ERROR_IF validators and the SliceAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    start = args_dict["start"]
    size = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, in_tensor, start, size, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=in_tensor,
    )
    if not is_valid:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], input_list, output_list, slice_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001737
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator from ``inputs = [tensor, multiples_tensor]``.

    The multiples tensor is wired in as the second operand. The output shape
    is computed from ``args_dict["multiples"]``. No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_tensor = inputs[0]
    multiples_tensor = inputs[1]
    multiples_values = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, in_tensor, multiples_values, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name, multiples_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=in_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001782
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from ``inputs = [values, indices]``.

    The ERROR_IF validators and compliance metadata use the shape/dtype of
    ``values``. No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [values.name, indices.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001825
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from ``inputs = [values_in, indices, input]``.

    The ERROR_IF validators and compliance metadata use the shape/dtype of
    ``values_in``. No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # Named input_tensor rather than input to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [values_in.name, indices.name, input_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001867
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on tensor ``input``.

    Args:
        op: Operator description dict (``op["op"]``, ``op["operands"]``).
        input: The input tensor.
        mode: Resize mode, used by the shaper, validators and attribute.
        scale: Resize scale, serialized unchanged in the ResizeAttribute.
        offset: Resize offset, serialized unchanged in the ResizeAttribute.
        border: Resize border, serialized unchanged in the ResizeAttribute.
        input_dtype: Input data type given to the shaper and validators.
        output_dtype: Output data type given to the shaper and validators.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [input.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[out_tensor],
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return out_tensor
1929
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator passing ``val`` and ``val2`` through.

    Both inputs get their own result tensor, but only the first one is
    returned. ``validator_fcns`` is unused and no ERROR_IF validation is
    performed.

    Args:
        op: Operator description dict whose ``op["op"]`` is serialized, as
            for every other builder. A bare op value (non-dict) is still
            accepted and serialized as-is for backward compatibility.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Unused.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor corresponding to ``val``.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder serializes op["op"]; passing the whole op dict to
    # addOperator would serialize the wrong thing.
    op_type = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_type, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1937
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001941
1942 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a type-conversion operator on the single tensor in ``inputs``.

    The target data type is ``args_dict["out_type"]``. No attribute is
    serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, in_tensor, target_dtype, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated.
    names_in = [in_tensor.name]
    names_out = [out_tensor.name]
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, names_out
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001986
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator that converts ``val`` to ``out_dtype``.

    Random input/output zero points and random scale values are chosen. The
    scales are converted to multiplier/shift pairs and everything is stored
    in the RESCALE attribute. For positive tests using 32-bit scaling, the
    input data file is clamped (and rewritten if it changed) so every value
    satisfies the apply_scale_32 range requirement.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: Input tensor.
        out_dtype: DType of the output tensor.
        scale32: True to use 32-bit multipliers, False for 16-bit.
        double_round: Passed through to the RESCALE attribute.
        per_channel: If True, one scale per element of the input's last
            dimension; otherwise a single scale.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    # Every zero-point branch below except the final "zero point is 0"
    # fallback also adds one bit to the corresponding type width.
    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero zero point for the ERROR_IF test.
        # NOTE(review): input_unsigned stays False here even when val is
        # UINT16; confirm that is intended.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero zero point for the ERROR_IF test.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift as a Python int: shifting a NumPy int32 scalar can overflow
        # (e.g. under NEP 50 promotion) once the shift exceeds 31 bits.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check every element can safely be converted back to the data's
        # dtype (element-wise; ndarray.all() would reduce to one bool).
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Cast back to the input data's dtype (range checked above)
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2159
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    By default this is a rank-0 BOOL constant holding ``cond``. For the
    CondIfCondNotMatchingBool ERROR_IF the dtype is replaced by a wrong type.
    For CondIfCondShapeNotSizeOne the shape is randomly ``[2]`` or ``[1, 2]``.

    Returns:
        The constant tensor created with ``self.ser.addConst``.
    """
    if error_name != ErrorIf.CondIfCondNotMatchingBool:
        cond_type = DType.BOOL
    else:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2176
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output a random INT32 constant.

    Only ``then_tens.shape`` is used from the supplied tensors; ``else_tens``
    is ignored. For the output-list-mismatch ERROR_IF tests, the affected
    block outputs a constant with a deliberately different shape.

    Args:
        op: Operator definition dict; ``op["op"]`` is read.
        then_tens: Tensor whose shape defines the output shape.
        else_tens: Unused.
        cond: Value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_mismatch, else_mismatch):
        # Perturb every dimension; larger dims may shrink, small ones only grow.
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block outputs a single constant tensor.
    block_specs = (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    )
    for block_name, block_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if validated else None
2245
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to ``a`` and ``b``.

    For FP32/BF16/FP16/INT32 the then block uses ADD and the else block SUB.
    For INT8/INT16 they use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. For
    the input/output-list-mismatch ERROR_IF tests, one block is given an
    input or output with a deliberately wrong shape.

    Args:
        op: Operator definition dict; ``op["op"]`` is read and the dict is
            passed to the ERROR_IF validators.
        a: First input tensor (also defines the result shape/dtype).
        b: Second input tensor.
        cond: Value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation fails.

    Raises:
        AssertionError: If ``a.dtype`` has no then/else operator mapping.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): dims are not range-guarded here, so a dim <= 3 can
        # become zero or negative; confirm that is intended.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop variable: reusing `op` here would replace the
    # operator dict that is handed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2328
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that repeatedly adds ``a`` into an accumulator.

    The loop carries (counter, ``a``, accumulator). The cond block computes
    ``counter > 0``. The body block outputs ``counter - 1``, ``a`` and
    ``a + accumulator``. The counter starts at ``iter_val`` and the
    accumulator at zeros. For the WHILE_LOOP ERROR_IF tests, selected block
    inputs/outputs get deliberately wrong shapes or a non-BOOL/non-scalar
    condition output.

    Args:
        op: Operator definition dict; ``op["op"]`` is read.
        a: Tensor added to the accumulator on every iteration.
        iter_val: Initial value of the INT32 loop counter.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation fails.
    """
    # Loop counter; named iter_tens to avoid shadowing the built-in iter().
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised with zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give it a dimension to mismatch on
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2449
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator over the two input tensors ``val1`` and ``val2``.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val1: First input tensor (presumably the real part — TODO confirm).
        val2: Second input tensor (presumably the imaginary part).
        inverse: Value of the FFT attribute's inverse flag.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensors produced by ``OutputShaper.fft2dOp``, or None when
        ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], [res.name for res in results]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2500
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator over the single input tensor ``val``.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: Input tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensors produced by ``OutputShaper.rfft2dOp``, or None
        when ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2546
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator from the tensors in ``inputs``.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: Sequence holding exactly two input tensors.
        args_dict: Test arguments forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF under test, or None for a positive test.
        qinfo: Unused; kept for a uniform build-function signature.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
2592
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002593 def create_filter_lists(
2594 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2595 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002596 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2597 default_test_rank_range = range(1, 5)
2598 if not shapeFilter:
2599 shapeFilter = [None]
2600
2601 # Calculate the filters based on what is requested and what the operator allows
2602 rmin, rmax = op["rank"]
2603 if rankFilter is not None:
2604 cleanRankFilter = []
2605 # Ensure rankFilter values are allowed by operator
2606 for rank in rankFilter:
2607 if rank >= rmin and rank <= rmax:
2608 cleanRankFilter.append(rank)
2609 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002610 # Ensure default behaviour is bounded by default range or by operator,
2611 # whichever is the smaller range of ranks.
2612 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002613 cleanRankFilter = (
2614 opRankRange
2615 if len(opRankRange) <= len(default_test_rank_range)
2616 else default_test_rank_range
2617 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002618 else:
2619 cleanRankFilter = range(rmin, rmax + 1)
2620
2621 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002622
Matthew Haddon1c00b712021-10-01 15:51:03 +01002623 if dtypeFilter is not None:
2624 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002625 # Create list of operator dtypes filtered by requested dtypes
2626 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002627 if dtype in dtypeFilter or (
2628 isinstance(dtype, list) and dtype[0] in dtypeFilter
2629 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002630 cleanDtypeFilter.append(dtype)
2631 else:
2632 cleanDtypeFilter = dtypes
2633
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002634 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002635 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002636 "shapeFilter": shapeFilter,
2637 "rankFilter": cleanRankFilter,
2638 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002639 }
2640 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002641 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002642 if validator is not None:
2643 validator_info = validator(check=False, op=op)
2644 else:
2645 return None
2646
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002647 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002648
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002649 # Set parameters as required
2650 if error_arguments["rank"] is not None:
2651 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002652 else:
2653 rankFilter = cleanRankFilter
2654
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002655 if error_arguments["dtype"] is not None:
2656 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002657 else:
2658 dtypeFilter = cleanDtypeFilter
2659
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002660 if error_arguments["shape"] is not None:
2661 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002662 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002663 shapeFilter = shapeFilter[
2664 :2
2665 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002666
2667 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002668 "shapeFilter": shapeFilter,
2669 "rankFilter": rankFilter,
2670 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002671 }
2672 return filterDict
2673
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for a single operator.

    Looks up ``opName`` in ``TOSA_OP_LIST`` and, for every rank / dtype /
    shape combination returned by ``create_filter_lists``, asks the op's
    tensor-shape generator and argument generator for concrete test
    parameters.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: list of target shapes. ``None`` (the default) is
            treated as ``[None]``; a ``None`` entry is not rank-filtered and
            is passed to ``setTargetShape`` unchanged.
        rankFilter: rank filter forwarded to ``create_filter_lists``.
        dtypeFilter: dtype filter forwarded to ``create_filter_lists``.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. Returns an empty list when ``create_filter_lists`` returns
        None (for example a negative run for an op without ERROR_IF
        validators).

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        # Avoid a shared mutable default argument; [None] is the old default
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Positive tests make a single pass with no validator; negative tests make
    # one pass per ERROR_IF validator registered for the op
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Each entry: (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # The argument generator yields (argStr, args) pairs: a
                    # test-name suffix and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            testStr = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2780
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for ``opName`` and serialize it.

    A serializer is created for ``testStr``. The op's tensor-values generator
    produces the operand tensors and the op's build function is then called.
    A truthy build result is serialized together with any data-generation
    and compliance meta data. A falsy result only prints a message.

    ``testArgs`` is either a dict (the "argsDict" interface, which must
    contain a ``"dg_type"`` entry) or a list of positional build arguments
    (the older interface).

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # CONCAT-style ops take a variable number of inputs, so their operand
    # count follows shapeList rather than the "operands" entry
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info only exists for ops that declare a "qgen" function
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # argsDict interface: tensors and arguments are passed as two objects
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # List interface: tensors and arguments are splatted positionally
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        try:
            if error_if_validators is None:
                # Without validators, qinfo (when present) is the final
                # positional argument
                trailing = [] if qinfo is None else [qinfo]
                result = build_fcn(self, op, *tens, *testArgs, *trailing)
            else:
                # With validators, the extras go by keyword and qinfo is
                # only supplied when there is one
                extras = {
                    "validator_fcns": error_if_validators,
                    "error_name": error_name,
                }
                if qinfo is not None:
                    extras["qinfo"] = qinfo
                result = build_fcn(self, op, *tens, *testArgs, **extras)
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002906
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` convolution entries of ``TOSA_OP_LIST``.

    For each kernel size in the selected list, a concrete op entry is added,
    e.g. ``conv2d_3x3``. The entry is a shallow copy of the matching template
    with ``"filter"`` set to the kernel and ``"template"`` set to False. The
    kernel lists depend on ``self.args.level8k``. Every entry whose
    ``"template"`` field is truthy is then removed.

    Once ``conv2d_TEMPLATE`` is gone, calling this again does nothing.
    """
    ops = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in ops:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate the convolution templates with
    if self.args.level8k:
        big = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big], [big, 2]]
        kernels_3d = [[1, big, 1], [2, 2, big]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def instantiate(prefix, kernel):
        # "<prefix>_TEMPLATE" -> "<prefix>_<k0>x<k1>[x<k2>]"
        entry = ops["{}_TEMPLATE".format(prefix)].copy()
        entry["filter"] = kernel
        entry["template"] = False
        ops["{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))] = entry

    # Kernel-major order keeps the generated ops grouped by kernel size
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)
    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Collect the templates first, then delete them, because keys must not be
    # removed from a dict while iterating over it
    templates = [name for name, entry in ops.items() if entry.get("template")]
    for name in templates:
        del ops[name]
2962
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints ``TOSA_OP_LIST`` for missing or malformed required fields:
    ``"operands"`` must unpack into 2 values, ``"build_fcn"`` must unpack
    into 4 values, and ``"types"`` and ``"op"`` must be present. A missing
    ``"rank"`` is set to ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: naming the first op with a missing or invalid required
            field. For ``"operands"`` / ``"build_fcn"`` the underlying
            KeyError/ValueError/TypeError is chained as ``__cause__``.
    """
    for op, op_info in self.TOSA_OP_LIST.items():
        # Required fields; unpacking validates the tuple length as well
        try:
            _, _ = op_info["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _, _, _, _ = op_info["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in op_info:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in op_info:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        op_info.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003004
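# Tensor operator list (TOSA_OP_LIST), keyed by test operator name
# Required fields:
# 'op': the Op enum value for the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'build_fcn': 4-tuple of (build function, TosaTensorGen shape generator,
#    TosaTensorValuesGen tensor-values generator, TosaArgGen argument generator)
# 'types': list of datatypes to be tested
# Optional fields:
# 'rank': restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE (1, TOSA_TENSOR_MAX_RANK)
# 'qgen': TosaQuantGen function that produces the quantization info
# 'error_if_validators': TosaErrorValidator functions used for ERROR_IF (negative) tests
# 'invalid_test_validators': TosaInvalidValidator functions that remove positive
#    tests expected to fail without a matching ERROR_IF statement
# 'data_gen': dict mapping a type class (e.g. "fp") to a tuple of data generator types
# 'template': True for templated ops that createDynamicOpLists expands
# 'filter': kernel size, set on the ops generated from templates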
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (INT4 is deliberately excluded)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer types followed by floating-point types (INT4 is deliberately excluded)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types plus INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
# FP32 plus INT16
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# Convolution type combinations; each entry is
# [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# (min, max) rank range applied by initOpListDefaults when an op has no "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003056
3057 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003058 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003059 "argmax": {
3060 "op": Op.ARGMAX,
3061 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003062 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003063 "build_fcn": (
3064 build_argmax,
3065 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003066 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003067 TosaArgGen.agAxis,
3068 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003069 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003070 "error_if_validators": (
3071 TosaErrorValidator.evAxisSmallerZero,
3072 TosaErrorValidator.evAxisLargerRank,
3073 TosaErrorValidator.evArgmaxOutputRankMismatch,
3074 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3075 TosaErrorValidator.evWrongRank,
3076 TosaErrorValidator.evWrongInputType,
3077 TosaErrorValidator.evWrongOutputType,
3078 TosaErrorValidator.evWrongInputList,
3079 TosaErrorValidator.evWrongOutputList,
3080 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003081 "data_gen": {
3082 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3083 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003084 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003085 "avg_pool2d": {
3086 "op": Op.AVG_POOL2D,
3087 "operands": (1, 0),
3088 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003089 "build_fcn": (
3090 build_pool2d,
3091 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003092 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003093 TosaArgGen.agPooling,
3094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "qgen": TosaQuantGen.qgUnary,
3096 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003097 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evKernelSmallerOne,
3100 TosaErrorValidator.evStrideSmallerOne,
3101 TosaErrorValidator.evPadSmallerZero,
3102 TosaErrorValidator.evWrongRank,
3103 TosaErrorValidator.evWrongInputType,
3104 TosaErrorValidator.evWrongOutputType,
3105 TosaErrorValidator.evWrongInputList,
3106 TosaErrorValidator.evWrongOutputList,
3107 TosaErrorValidator.evInputZeroPointNotZero,
3108 TosaErrorValidator.evOutputZeroPointNotZero,
3109 TosaErrorValidator.evPadLargerEqualKernel,
3110 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003111 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003112 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003113 "data_gen": {
3114 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3115 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003117 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003118 "conv2d_TEMPLATE": {
3119 "op": Op.CONV2D,
3120 "operands": (1, 2),
3121 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003122 "build_fcn": (
3123 build_conv2d,
3124 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003125 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003126 TosaArgGen.agConv,
3127 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003128 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003129 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003130 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3131 "error_if_validators": (
3132 TosaErrorValidator.evWrongInputType,
3133 TosaErrorValidator.evWrongOutputType,
3134 TosaErrorValidator.evWrongInputList,
3135 TosaErrorValidator.evWrongOutputList,
3136 TosaErrorValidator.evInputZeroPointNotZero,
3137 TosaErrorValidator.evWeightZeroPointNotZero,
3138 TosaErrorValidator.evPadSmallerZero,
3139 TosaErrorValidator.evStrideSmallerOne,
3140 TosaErrorValidator.evDilationSmallerOne,
3141 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003142 TosaErrorValidator.evConvOutputShapeMismatch,
3143 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003144 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003145 "data_gen": {
3146 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3147 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003148 "template": True,
3149 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003150 # Templated operator. Filled in by createDynamicOpLists
3151 "conv3d_TEMPLATE": {
3152 "op": Op.CONV3D,
3153 "operands": (1, 2),
3154 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003155 "build_fcn": (
3156 build_conv3d,
3157 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003158 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003159 TosaArgGen.agConv,
3160 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003161 "qgen": TosaQuantGen.qgConv,
3162 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003163 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3164 "error_if_validators": (
3165 TosaErrorValidator.evWrongInputType,
3166 TosaErrorValidator.evWrongOutputType,
3167 TosaErrorValidator.evWrongInputList,
3168 TosaErrorValidator.evWrongOutputList,
3169 TosaErrorValidator.evInputZeroPointNotZero,
3170 TosaErrorValidator.evWeightZeroPointNotZero,
3171 TosaErrorValidator.evPadSmallerZero,
3172 TosaErrorValidator.evStrideSmallerOne,
3173 TosaErrorValidator.evDilationSmallerOne,
3174 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003175 TosaErrorValidator.evConvOutputShapeMismatch,
3176 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003177 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003178 "template": True,
3179 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003180 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003181 "depthwise_conv2d_TEMPLATE": {
3182 "op": Op.DEPTHWISE_CONV2D,
3183 "operands": (1, 2),
3184 "filter": [1, 1],
3185 "rank": (4, 4),
3186 "build_fcn": (
3187 build_depthwise_conv2d,
3188 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003189 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003190 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003191 ),
3192 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003193 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003194 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3195 "error_if_validators": (
3196 TosaErrorValidator.evWrongInputType,
3197 TosaErrorValidator.evWrongOutputType,
3198 TosaErrorValidator.evWrongInputList,
3199 TosaErrorValidator.evWrongOutputList,
3200 TosaErrorValidator.evInputZeroPointNotZero,
3201 TosaErrorValidator.evWeightZeroPointNotZero,
3202 TosaErrorValidator.evPadSmallerZero,
3203 TosaErrorValidator.evStrideSmallerOne,
3204 TosaErrorValidator.evDilationSmallerOne,
3205 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003206 TosaErrorValidator.evConvOutputShapeMismatch,
3207 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003209 "template": True,
3210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "fully_connected": {
3212 "op": Op.FULLY_CONNECTED,
3213 "operands": (1, 2),
3214 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 "build_fcn": (
3216 build_fully_connected,
3217 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003218 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003219 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003221 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003222 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003223 "error_if_validators": (
3224 TosaErrorValidator.evInputZeroPointNotZero,
3225 TosaErrorValidator.evWeightZeroPointNotZero,
3226 TosaErrorValidator.evWrongRank,
3227 TosaErrorValidator.evWrongInputType,
3228 TosaErrorValidator.evWrongOutputType,
3229 TosaErrorValidator.evWrongInputList,
3230 TosaErrorValidator.evWrongOutputList,
3231 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003232 "data_gen": {
3233 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003236 "matmul": {
3237 "op": Op.MATMUL,
3238 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003239 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003240 "build_fcn": (
3241 build_matmul,
3242 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003243 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003244 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003245 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 "qgen": TosaQuantGen.qgMatmul,
3247 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003248 "error_if_validators": (
3249 TosaErrorValidator.evInputZeroPointNotZero,
3250 TosaErrorValidator.evWrongRank,
3251 TosaErrorValidator.evWrongInputType,
3252 TosaErrorValidator.evWrongOutputType,
3253 TosaErrorValidator.evWrongInputList,
3254 TosaErrorValidator.evWrongOutputList,
3255 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003256 "data_gen": {
3257 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003260 "max_pool2d": {
3261 "op": Op.MAX_POOL2D,
3262 "operands": (1, 0),
3263 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003265 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003267 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 TosaArgGen.agPooling,
3269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003271 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003272 "error_if_validators": (
3273 TosaErrorValidator.evKernelSmallerOne,
3274 TosaErrorValidator.evStrideSmallerOne,
3275 TosaErrorValidator.evPadSmallerZero,
3276 TosaErrorValidator.evWrongRank,
3277 TosaErrorValidator.evWrongInputType,
3278 TosaErrorValidator.evWrongOutputType,
3279 TosaErrorValidator.evWrongInputList,
3280 TosaErrorValidator.evWrongOutputList,
3281 TosaErrorValidator.evPadLargerEqualKernel,
3282 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003283 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003284 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003285 "data_gen": {
3286 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3287 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003289 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003290 "transpose_conv2d_TEMPLATE": {
3291 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003292 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003293 "rank": (4, 4),
3294 "build_fcn": (
3295 build_transpose_conv2d,
3296 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003297 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003298 TosaArgGen.agTransposeConv2D,
3299 ),
3300 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003301 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003302 "invalid_test_validators": (
3303 TosaInvalidValidator.ivHeightWidthInvalid,
3304 TosaInvalidValidator.ivNonPositiveOutputShape,
3305 ),
3306 "error_if_validators": (
3307 TosaErrorValidator.evWrongInputType,
3308 TosaErrorValidator.evWrongOutputType,
3309 TosaErrorValidator.evWrongInputList,
3310 TosaErrorValidator.evWrongOutputList,
3311 TosaErrorValidator.evInputZeroPointNotZero,
3312 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003313 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003314 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003315 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003316 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003317 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003318 "template": True,
3319 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003320 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003321 "clamp": {
3322 "op": Op.CLAMP,
3323 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003324 "build_fcn": (
3325 build_clamp,
3326 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003327 TosaTensorValuesGen.tvgLazyGenDefault,
3328 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003329 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003330 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003331 "error_if_validators": (
3332 TosaErrorValidator.evMaxSmallerMin,
3333 TosaErrorValidator.evWrongInputType,
3334 TosaErrorValidator.evWrongOutputType,
3335 TosaErrorValidator.evWrongInputList,
3336 TosaErrorValidator.evWrongOutputList,
3337 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003338 "data_gen": {
3339 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3340 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003341 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003342 "sigmoid": {
3343 "op": Op.SIGMOID,
3344 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003345 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003346 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003348 TosaTensorValuesGen.tvgLazyGenDefault,
3349 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003350 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003351 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003352 "error_if_validators": (
3353 TosaErrorValidator.evWrongInputType,
3354 TosaErrorValidator.evWrongOutputType,
3355 TosaErrorValidator.evWrongInputList,
3356 TosaErrorValidator.evWrongOutputList,
3357 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003358 "data_gen": {
3359 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3360 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003361 },
3362 "tanh": {
3363 "op": Op.TANH,
3364 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003365 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003366 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003368 TosaTensorValuesGen.tvgLazyGenDefault,
3369 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003370 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003371 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003372 "error_if_validators": (
3373 TosaErrorValidator.evWrongInputType,
3374 TosaErrorValidator.evWrongOutputType,
3375 TosaErrorValidator.evWrongInputList,
3376 TosaErrorValidator.evWrongOutputList,
3377 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003378 "data_gen": {
3379 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3380 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003381 "compliance": {
3382 "abs_error_lower_bound": 0.5,
3383 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003384 },
Won Jeon78155c62023-06-10 00:20:04 +00003385 "erf": {
3386 "op": Op.ERF,
3387 "operands": (1, 0),
3388 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003389 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003390 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003391 TosaTensorValuesGen.tvgLazyGenDefault,
3392 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003393 ),
3394 "types": TYPE_FP,
3395 "error_if_validators": (
3396 TosaErrorValidator.evWrongInputType,
3397 TosaErrorValidator.evWrongOutputType,
3398 TosaErrorValidator.evWrongInputList,
3399 TosaErrorValidator.evWrongOutputList,
3400 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003401 "data_gen": {
3402 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3403 },
3404 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 # Elementwise Binary Operators
3407 "add": {
3408 "op": Op.ADD,
3409 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003410 "build_fcn": (
3411 build_binary_broadcast,
3412 TosaTensorGen.tgBroadcastFuzz,
3413 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003414 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003417 "error_if_validators": (
3418 TosaErrorValidator.evRankMismatch,
3419 TosaErrorValidator.evWrongInputType,
3420 TosaErrorValidator.evWrongOutputType,
3421 TosaErrorValidator.evWrongInputList,
3422 TosaErrorValidator.evWrongOutputList,
3423 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003424 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003425 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003426 "data_gen": {
3427 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3428 },
3429 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 "arithmetic_right_shift": {
3432 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3433 "operands": (2, 0),
3434 "build_fcn": (
3435 build_arithmetic_right_shift,
3436 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 TosaArgGen.agArithmeticRightShift,
3439 ),
3440 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003441 "error_if_validators": (
3442 TosaErrorValidator.evRankMismatch,
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003448 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003450 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003451 "bitwise_and": {
3452 "op": Op.BITWISE_AND,
3453 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003454 "build_fcn": (
3455 build_binary_broadcast,
3456 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003457 TosaTensorValuesGen.tvgLazyGenDefault,
3458 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003460 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003461 "error_if_validators": (
3462 TosaErrorValidator.evRankMismatch,
3463 TosaErrorValidator.evWrongInputType,
3464 TosaErrorValidator.evWrongOutputType,
3465 TosaErrorValidator.evWrongInputList,
3466 TosaErrorValidator.evWrongOutputList,
3467 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003468 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "bitwise_or": {
3472 "op": Op.BITWISE_OR,
3473 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003474 "build_fcn": (
3475 build_binary_broadcast,
3476 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003477 TosaTensorValuesGen.tvgLazyGenDefault,
3478 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003479 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003481 "error_if_validators": (
3482 TosaErrorValidator.evRankMismatch,
3483 TosaErrorValidator.evWrongInputType,
3484 TosaErrorValidator.evWrongOutputType,
3485 TosaErrorValidator.evWrongInputList,
3486 TosaErrorValidator.evWrongOutputList,
3487 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003488 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 "bitwise_xor": {
3492 "op": Op.BITWISE_XOR,
3493 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 "build_fcn": (
3495 build_binary_broadcast,
3496 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003497 TosaTensorValuesGen.tvgLazyGenDefault,
3498 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003499 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003500 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003501 "error_if_validators": (
3502 TosaErrorValidator.evRankMismatch,
3503 TosaErrorValidator.evWrongInputType,
3504 TosaErrorValidator.evWrongOutputType,
3505 TosaErrorValidator.evWrongInputList,
3506 TosaErrorValidator.evWrongOutputList,
3507 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003508 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003511 "intdiv": {
3512 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003513 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 "build_fcn": (
3515 build_binary_broadcast,
3516 TosaTensorGen.tgBroadcastFuzz,
3517 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003518 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003520 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003521 "error_if_validators": (
3522 TosaErrorValidator.evRankMismatch,
3523 TosaErrorValidator.evWrongInputType,
3524 TosaErrorValidator.evWrongOutputType,
3525 TosaErrorValidator.evWrongInputList,
3526 TosaErrorValidator.evWrongOutputList,
3527 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003528 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003530 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 "logical_and": {
3532 "op": Op.LOGICAL_AND,
3533 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 "build_fcn": (
3535 build_binary_broadcast,
3536 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003537 TosaTensorValuesGen.tvgLazyGenDefault,
3538 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003540 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003541 "error_if_validators": (
3542 TosaErrorValidator.evRankMismatch,
3543 TosaErrorValidator.evWrongInputType,
3544 TosaErrorValidator.evWrongOutputType,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003548 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003549 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003550 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "logical_left_shift": {
3552 "op": Op.LOGICAL_LEFT_SHIFT,
3553 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 "build_fcn": (
3555 build_binary_broadcast,
3556 TosaTensorGen.tgBroadcastFuzz,
3557 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003558 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003560 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evRankMismatch,
3563 TosaErrorValidator.evWrongInputType,
3564 TosaErrorValidator.evWrongOutputType,
3565 TosaErrorValidator.evWrongInputList,
3566 TosaErrorValidator.evWrongOutputList,
3567 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003568 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "logical_right_shift": {
3572 "op": Op.LOGICAL_RIGHT_SHIFT,
3573 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003574 "build_fcn": (
3575 build_binary_broadcast,
3576 TosaTensorGen.tgBroadcastFuzz,
3577 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003578 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003581 "error_if_validators": (
3582 TosaErrorValidator.evRankMismatch,
3583 TosaErrorValidator.evWrongInputType,
3584 TosaErrorValidator.evWrongOutputType,
3585 TosaErrorValidator.evWrongInputList,
3586 TosaErrorValidator.evWrongOutputList,
3587 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003588 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003589 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "logical_or": {
3592 "op": Op.LOGICAL_OR,
3593 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003594 "build_fcn": (
3595 build_binary_broadcast,
3596 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003597 TosaTensorValuesGen.tvgLazyGenDefault,
3598 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003600 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 "error_if_validators": (
3602 TosaErrorValidator.evRankMismatch,
3603 TosaErrorValidator.evWrongInputType,
3604 TosaErrorValidator.evWrongOutputType,
3605 TosaErrorValidator.evWrongInputList,
3606 TosaErrorValidator.evWrongOutputList,
3607 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003608 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003609 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003611 "logical_xor": {
3612 "op": Op.LOGICAL_XOR,
3613 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 "build_fcn": (
3615 build_binary_broadcast,
3616 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003617 TosaTensorValuesGen.tvgLazyGenDefault,
3618 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003620 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003621 "error_if_validators": (
3622 TosaErrorValidator.evRankMismatch,
3623 TosaErrorValidator.evWrongInputType,
3624 TosaErrorValidator.evWrongOutputType,
3625 TosaErrorValidator.evWrongInputList,
3626 TosaErrorValidator.evWrongOutputList,
3627 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003628 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003631 "maximum": {
3632 "op": Op.MAXIMUM,
3633 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 "build_fcn": (
3635 build_binary_broadcast,
3636 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003637 TosaTensorValuesGen.tvgLazyGenDefault,
3638 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003640 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 "error_if_validators": (
3642 TosaErrorValidator.evRankMismatch,
3643 TosaErrorValidator.evWrongInputType,
3644 TosaErrorValidator.evWrongOutputType,
3645 TosaErrorValidator.evWrongInputList,
3646 TosaErrorValidator.evWrongOutputList,
3647 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003648 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003649 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003650 "data_gen": {
3651 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003654 "minimum": {
3655 "op": Op.MINIMUM,
3656 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 "build_fcn": (
3658 build_binary_broadcast,
3659 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003660 TosaTensorValuesGen.tvgLazyGenDefault,
3661 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003662 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003663 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003664 "error_if_validators": (
3665 TosaErrorValidator.evRankMismatch,
3666 TosaErrorValidator.evWrongInputType,
3667 TosaErrorValidator.evWrongOutputType,
3668 TosaErrorValidator.evWrongInputList,
3669 TosaErrorValidator.evWrongOutputList,
3670 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003671 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003672 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003673 "data_gen": {
3674 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 "mul": {
3678 "op": Op.MUL,
3679 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 "build_fcn": (
3681 build_mul,
3682 TosaTensorGen.tgBroadcastFuzz,
3683 TosaTensorValuesGen.tvgMul,
3684 TosaArgGen.agMul,
3685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 "error_if_validators": (
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 TosaErrorValidator.evRankMismatch,
3693 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003694 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003695 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003696 "data_gen": {
3697 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3698 },
3699 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "pow": {
3702 "op": Op.POW,
3703 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_binary_broadcast,
3706 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003707 TosaTensorValuesGen.tvgPow,
3708 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evRankMismatch,
3713 TosaErrorValidator.evWrongInputType,
3714 TosaErrorValidator.evWrongOutputType,
3715 TosaErrorValidator.evWrongInputList,
3716 TosaErrorValidator.evWrongOutputList,
3717 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003718 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003720 "data_gen": {
3721 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3722 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "sub": {
3725 "op": Op.SUB,
3726 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 "build_fcn": (
3728 build_binary_broadcast,
3729 TosaTensorGen.tgBroadcastFuzz,
3730 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003731 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evRankMismatch,
3736 TosaErrorValidator.evWrongInputType,
3737 TosaErrorValidator.evWrongOutputType,
3738 TosaErrorValidator.evWrongInputList,
3739 TosaErrorValidator.evWrongOutputList,
3740 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003741 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003743 "data_gen": {
3744 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3745 },
3746 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "table": {
3749 "op": Op.TABLE,
3750 # Use the automatic generation functions to create the input array
3751 # but create the table tensor in the build function, as it may be
3752 # a different type from the input
3753 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 "build_fcn": (
3755 build_table,
3756 TosaTensorGen.tgBasic,
3757 TosaTensorValuesGen.tvgDefault,
3758 TosaArgGen.agTable,
3759 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003760 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003761 "error_if_validators": (
3762 TosaErrorValidator.evWrongInputType,
3763 TosaErrorValidator.evWrongOutputType,
3764 TosaErrorValidator.evWrongInputList,
3765 TosaErrorValidator.evWrongOutputList,
3766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003767 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 # Elementwise Unary operators
3769 "abs": {
3770 "op": Op.ABS,
3771 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 "build_fcn": (
3773 build_unary,
3774 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003775 TosaTensorValuesGen.tvgLazyGenDefault,
3776 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 "error_if_validators": (
3780 TosaErrorValidator.evWrongInputType,
3781 TosaErrorValidator.evWrongOutputType,
3782 TosaErrorValidator.evWrongInputList,
3783 TosaErrorValidator.evWrongOutputList,
3784 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003785 "data_gen": {
3786 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3787 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "bitwise_not": {
3790 "op": Op.BITWISE_NOT,
3791 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 "build_fcn": (
3793 build_unary,
3794 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003795 TosaTensorValuesGen.tvgLazyGenDefault,
3796 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 "error_if_validators": (
3800 TosaErrorValidator.evWrongInputType,
3801 TosaErrorValidator.evWrongOutputType,
3802 TosaErrorValidator.evWrongInputList,
3803 TosaErrorValidator.evWrongOutputList,
3804 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "ceil": {
3807 "op": Op.CEIL,
3808 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 "build_fcn": (
3810 build_unary,
3811 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003812 TosaTensorValuesGen.tvgLazyGenDefault,
3813 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evWrongInputType,
3818 TosaErrorValidator.evWrongOutputType,
3819 TosaErrorValidator.evWrongInputList,
3820 TosaErrorValidator.evWrongOutputList,
3821 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003822 "data_gen": {
3823 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3824 },
3825 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003826 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 "clz": {
3828 "op": Op.CLZ,
3829 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003830 "build_fcn": (
3831 build_unary,
3832 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003833 TosaTensorValuesGen.tvgLazyGenDefault,
3834 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003835 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003836 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003837 "error_if_validators": (
3838 TosaErrorValidator.evWrongInputType,
3839 TosaErrorValidator.evWrongOutputType,
3840 TosaErrorValidator.evWrongInputList,
3841 TosaErrorValidator.evWrongOutputList,
3842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "exp": {
3845 "op": Op.EXP,
3846 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 "build_fcn": (
3848 build_unary,
3849 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003850 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003851 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evWrongInputType,
3856 TosaErrorValidator.evWrongOutputType,
3857 TosaErrorValidator.evWrongInputList,
3858 TosaErrorValidator.evWrongOutputList,
3859 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003860 "data_gen": {
3861 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3862 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "floor": {
3865 "op": Op.FLOOR,
3866 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 "build_fcn": (
3868 build_unary,
3869 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003870 TosaTensorValuesGen.tvgLazyGenDefault,
3871 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003872 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 "error_if_validators": (
3875 TosaErrorValidator.evWrongInputType,
3876 TosaErrorValidator.evWrongOutputType,
3877 TosaErrorValidator.evWrongInputList,
3878 TosaErrorValidator.evWrongOutputList,
3879 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 "data_gen": {
3881 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3882 },
3883 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003885 "log": {
3886 "op": Op.LOG,
3887 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003888 "build_fcn": (
3889 build_unary,
3890 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003891 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003892 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003893 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003894 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 "error_if_validators": (
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList,
3899 TosaErrorValidator.evWrongOutputList,
3900 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003901 "data_gen": {
3902 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3903 },
3904 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 "logical_not": {
3907 "op": Op.LOGICAL_NOT,
3908 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 "build_fcn": (
3910 build_unary,
3911 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003912 TosaTensorValuesGen.tvgLazyGenDefault,
3913 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 "error_if_validators": (
3917 TosaErrorValidator.evWrongInputType,
3918 TosaErrorValidator.evWrongOutputType,
3919 TosaErrorValidator.evWrongInputList,
3920 TosaErrorValidator.evWrongOutputList,
3921 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003922 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 "negate": {
3924 "op": Op.NEGATE,
3925 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003926 "build_fcn": (
3927 build_unary,
3928 TosaTensorGen.tgBasic,
3929 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003930 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003931 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003932 "qgen": TosaQuantGen.qgUnary,
3933 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003934 "error_if_validators": (
3935 TosaErrorValidator.evInputZeroPointNotZero,
3936 TosaErrorValidator.evOutputZeroPointNotZero,
3937 TosaErrorValidator.evWrongInputType,
3938 TosaErrorValidator.evWrongOutputType,
3939 TosaErrorValidator.evWrongInputList,
3940 TosaErrorValidator.evWrongOutputList,
3941 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003942 "data_gen": {
3943 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3944 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "reciprocal": {
3947 "op": Op.RECIPROCAL,
3948 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003949 "build_fcn": (
3950 build_unary,
3951 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003952 TosaTensorValuesGen.tvgLazyGenDefault,
3953 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003954 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003955 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003956 "error_if_validators": (
3957 TosaErrorValidator.evWrongInputType,
3958 TosaErrorValidator.evWrongOutputType,
3959 TosaErrorValidator.evWrongInputList,
3960 TosaErrorValidator.evWrongOutputList,
3961 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003962 "data_gen": {
3963 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3964 },
3965 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 "rsqrt": {
3968 "op": Op.RSQRT,
3969 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_unary,
3972 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003973 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003974 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 "error_if_validators": (
3978 TosaErrorValidator.evWrongInputType,
3979 TosaErrorValidator.evWrongOutputType,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003983 "data_gen": {
3984 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3985 },
3986 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003987 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003988 # Elementwise Ternary operators
3989 "select": {
3990 "op": Op.SELECT,
3991 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 "build_fcn": (
3993 build_select,
3994 TosaTensorGen.tgBroadcastFuzz,
3995 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003996 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003997 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003998 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003999 "error_if_validators": (
4000 TosaErrorValidator.evRankMismatch,
4001 TosaErrorValidator.evWrongInputType,
4002 TosaErrorValidator.evWrongOutputType,
4003 TosaErrorValidator.evWrongInputList,
4004 TosaErrorValidator.evWrongOutputList,
4005 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004006 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004007 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004008 "data_gen": {
4009 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004012 # Comparison operators
4013 "equal": {
4014 "op": Op.EQUAL,
4015 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 "build_fcn": (
4017 build_comparison,
4018 TosaTensorGen.tgBroadcastFuzz,
4019 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004020 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004023 "error_if_validators": (
4024 TosaErrorValidator.evRankMismatch,
4025 TosaErrorValidator.evWrongInputType,
4026 TosaErrorValidator.evWrongOutputType,
4027 TosaErrorValidator.evWrongInputList,
4028 TosaErrorValidator.evWrongOutputList,
4029 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004030 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004031 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004032 "data_gen": {
4033 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4034 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004036 "greater_equal": {
4037 "op": Op.GREATER_EQUAL,
4038 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 "build_fcn": (
4040 build_comparison,
4041 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004042 TosaTensorValuesGen.tvgLazyGenDefault,
4043 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004044 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004046 "error_if_validators": (
4047 TosaErrorValidator.evRankMismatch,
4048 TosaErrorValidator.evWrongInputType,
4049 TosaErrorValidator.evWrongOutputType,
4050 TosaErrorValidator.evWrongInputList,
4051 TosaErrorValidator.evWrongOutputList,
4052 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004053 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004054 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004055 "data_gen": {
4056 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4057 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "greater": {
4060 "op": Op.GREATER,
4061 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_comparison,
4064 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004065 TosaTensorValuesGen.tvgLazyGenDefault,
4066 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evRankMismatch,
4071 TosaErrorValidator.evWrongInputType,
4072 TosaErrorValidator.evWrongOutputType,
4073 TosaErrorValidator.evWrongInputList,
4074 TosaErrorValidator.evWrongOutputList,
4075 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004076 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004077 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004078 "data_gen": {
4079 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4080 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 # Reduction operators
4083 "reduce_all": {
4084 "op": Op.REDUCE_ALL,
4085 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004086 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 "build_fcn": (
4088 build_reduce,
4089 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004090 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 TosaArgGen.agAxis,
4092 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004094 "error_if_validators": (
4095 TosaErrorValidator.evAxisLargerRank,
4096 TosaErrorValidator.evAxisSmallerZero,
4097 TosaErrorValidator.evShapeOfAxisNotOne,
4098 TosaErrorValidator.evWrongInputType,
4099 TosaErrorValidator.evWrongOutputType,
4100 TosaErrorValidator.evWrongRank,
4101 TosaErrorValidator.evWrongInputList,
4102 TosaErrorValidator.evWrongOutputList,
4103 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "reduce_any": {
4106 "op": Op.REDUCE_ANY,
4107 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004108 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004109 "build_fcn": (
4110 build_reduce,
4111 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004112 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004113 TosaArgGen.agAxis,
4114 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004115 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004116 "error_if_validators": (
4117 TosaErrorValidator.evAxisLargerRank,
4118 TosaErrorValidator.evAxisSmallerZero,
4119 TosaErrorValidator.evShapeOfAxisNotOne,
4120 TosaErrorValidator.evWrongInputType,
4121 TosaErrorValidator.evWrongOutputType,
4122 TosaErrorValidator.evWrongRank,
4123 TosaErrorValidator.evWrongInputList,
4124 TosaErrorValidator.evWrongOutputList,
4125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "reduce_max": {
4128 "op": Op.REDUCE_MAX,
4129 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004130 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 "build_fcn": (
4132 build_reduce,
4133 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004134 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004135 TosaArgGen.agAxis,
4136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 "error_if_validators": (
4139 TosaErrorValidator.evAxisLargerRank,
4140 TosaErrorValidator.evAxisSmallerZero,
4141 TosaErrorValidator.evShapeOfAxisNotOne,
4142 TosaErrorValidator.evWrongInputType,
4143 TosaErrorValidator.evWrongOutputType,
4144 TosaErrorValidator.evWrongRank,
4145 TosaErrorValidator.evWrongInputList,
4146 TosaErrorValidator.evWrongOutputList,
4147 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004148 "data_gen": {
4149 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004152 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004153 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004154 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004155 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 "build_fcn": (
4157 build_reduce,
4158 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004159 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 TosaArgGen.agAxis,
4161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004162 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004163 "error_if_validators": (
4164 TosaErrorValidator.evAxisLargerRank,
4165 TosaErrorValidator.evAxisSmallerZero,
4166 TosaErrorValidator.evShapeOfAxisNotOne,
4167 TosaErrorValidator.evWrongInputType,
4168 TosaErrorValidator.evWrongOutputType,
4169 TosaErrorValidator.evWrongRank,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004173 "data_gen": {
4174 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004177 "reduce_product": {
4178 "op": Op.REDUCE_PRODUCT,
4179 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004180 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 "build_fcn": (
4182 build_reduce,
4183 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004184 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004185 TosaArgGen.agAxis,
4186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004187 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 "error_if_validators": (
4189 TosaErrorValidator.evAxisLargerRank,
4190 TosaErrorValidator.evAxisSmallerZero,
4191 TosaErrorValidator.evShapeOfAxisNotOne,
4192 TosaErrorValidator.evWrongInputType,
4193 TosaErrorValidator.evWrongOutputType,
4194 TosaErrorValidator.evWrongRank,
4195 TosaErrorValidator.evWrongInputList,
4196 TosaErrorValidator.evWrongOutputList,
4197 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004198 "data_gen": {
4199 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004202 "reduce_sum": {
4203 "op": Op.REDUCE_SUM,
4204 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004205 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004206 "build_fcn": (
4207 build_reduce,
4208 TosaTensorGen.tgBasic,
4209 TosaTensorValuesGen.tvgReduceSum,
4210 TosaArgGen.agAxis,
4211 ),
James Ward24dbc422022-10-19 12:20:31 +01004212 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004213 "error_if_validators": (
4214 TosaErrorValidator.evAxisLargerRank,
4215 TosaErrorValidator.evAxisSmallerZero,
4216 TosaErrorValidator.evShapeOfAxisNotOne,
4217 TosaErrorValidator.evWrongInputType,
4218 TosaErrorValidator.evWrongOutputType,
4219 TosaErrorValidator.evWrongRank,
4220 TosaErrorValidator.evWrongInputList,
4221 TosaErrorValidator.evWrongOutputList,
4222 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004223 "data_gen": {
4224 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4225 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004227 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004228 "concat": {
4229 "op": Op.CONCAT,
4230 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004231 "build_fcn": (
4232 build_concat,
4233 TosaTensorGen.tgConcat,
4234 TosaTensorValuesGen.tvgConcat,
4235 TosaArgGen.agAxis,
4236 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004237 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004238 "error_if_validators": (
4239 TosaErrorValidator.evAxisLargerRank,
4240 TosaErrorValidator.evAxisSmallerZero,
4241 TosaErrorValidator.evConcatInputRankMismatch,
4242 TosaErrorValidator.evConcatShapeSumMismatch,
4243 TosaErrorValidator.evConcatInputDimMismatch,
4244 TosaErrorValidator.evWrongInputType,
4245 TosaErrorValidator.evWrongOutputType,
4246 TosaErrorValidator.evWrongOutputList,
4247 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004248 "data_gen": {
4249 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4250 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 },
4252 "pad": {
4253 "op": Op.PAD,
4254 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004255 "build_fcn": (
4256 build_pad,
4257 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004258 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004259 TosaArgGen.agPad,
4260 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004261 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004262 "error_if_validators": (
4263 TosaErrorValidator.evWrongInputType,
4264 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004265 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004266 TosaErrorValidator.evWrongOutputType,
4267 TosaErrorValidator.evWrongInputList,
4268 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004269 TosaErrorValidator.evRankMismatch,
4270 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004271 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004272 "data_gen": {
4273 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4274 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004275 },
Won Jeona21b2e82023-08-10 10:33:01 +00004276 "dim": {
4277 "op": Op.DIM,
4278 "operands": (1, 0),
4279 "build_fcn": (
4280 build_dim,
4281 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004282 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004283 TosaArgGen.agAxis,
4284 ),
4285 "types": TYPE_FIB,
4286 "error_if_validators": (
4287 TosaErrorValidator.evAxisLargerRank,
4288 TosaErrorValidator.evAxisSmallerZero,
4289 TosaErrorValidator.evWrongInputType,
4290 TosaErrorValidator.evWrongInputList,
4291 TosaErrorValidator.evWrongOutputList,
4292 TosaErrorValidator.evWrongRank,
4293 ),
4294 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004295 "reshape": {
4296 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004297 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004298 "build_fcn": (
4299 build_reshape,
4300 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004301 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004302 TosaArgGen.agReshape,
4303 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004304 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004305 "error_if_validators": (
4306 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4307 TosaErrorValidator.evWrongInputType,
4308 TosaErrorValidator.evWrongOutputType,
4309 TosaErrorValidator.evWrongInputList,
4310 TosaErrorValidator.evWrongOutputList,
4311 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004312 "data_gen": {
4313 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4314 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004315 },
4316 "reverse": {
4317 "op": Op.REVERSE,
4318 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004319 "build_fcn": (
4320 build_reverse,
4321 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004322 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004323 TosaArgGen.agAxis,
4324 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004326 "error_if_validators": (
4327 TosaErrorValidator.evAxisSmallerZero,
4328 TosaErrorValidator.evAxisLargerRank,
4329 TosaErrorValidator.evWrongInputType,
4330 TosaErrorValidator.evWrongOutputType,
4331 TosaErrorValidator.evWrongInputList,
4332 TosaErrorValidator.evWrongOutputList,
4333 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004334 },
4335 "slice": {
4336 "op": Op.SLICE,
4337 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004338 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 "build_fcn": (
4340 build_slice,
4341 TosaTensorGen.tgBasic,
evacha017f7d4252024-01-24 12:08:09 +00004342 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004343 TosaArgGen.agSlice,
4344 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004345 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004346 "error_if_validators": (
4347 TosaErrorValidator.evStartSmallerZero,
4348 TosaErrorValidator.evSizeSmallerEqualZero,
4349 TosaErrorValidator.evStartSizeOutsideBounds,
4350 TosaErrorValidator.evSizeOutputShapeMismatch,
4351 TosaErrorValidator.evInputSizeStartLengthMismatch,
4352 TosaErrorValidator.evWrongRank,
4353 TosaErrorValidator.evWrongInputType,
4354 TosaErrorValidator.evWrongOutputType,
4355 TosaErrorValidator.evWrongInputList,
4356 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004357 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004358 ),
evacha017f7d4252024-01-24 12:08:09 +00004359 "data_gen": {
4360 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4361 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004362 },
4363 "tile": {
4364 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004365 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004366 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004367 "build_fcn": (
4368 build_tile,
4369 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004370 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004371 TosaArgGen.agTile,
4372 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004373 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 "error_if_validators": (
4375 TosaErrorValidator.evWrongInputType,
4376 TosaErrorValidator.evWrongOutputType,
4377 TosaErrorValidator.evWrongInputList,
4378 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004379 TosaErrorValidator.evRankMismatch,
4380 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004381 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004382 "data_gen": {
4383 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4384 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004385 },
4386 "transpose": {
4387 "op": Op.TRANSPOSE,
4388 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004389 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004390 "build_fcn": (
4391 build_transpose,
4392 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004393 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004394 TosaArgGen.agTranspose,
4395 ),
4396 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004397 "error_if_validators": (
4398 TosaErrorValidator.evIndexOutsideBounds,
4399 TosaErrorValidator.evIndexUsedTwice,
4400 TosaErrorValidator.evWrongInputType,
4401 TosaErrorValidator.evWrongOutputType,
4402 TosaErrorValidator.evWrongInputList,
4403 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004404 TosaErrorValidator.evWrongRank,
4405 TosaErrorValidator.evRankMismatch,
4406 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004409 # Data nodes
4410 "const": {
4411 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004412 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004413 "build_fcn": (
4414 build_const,
4415 TosaTensorGen.tgBasic,
4416 TosaTensorValuesGen.tvgDefault,
4417 None,
4418 ),
Luke Hutton65872422023-02-20 10:33:04 +00004419 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004420 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004421 "identity": {
4422 "op": Op.IDENTITY,
4423 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004424 "build_fcn": (
4425 build_unary,
4426 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004427 TosaTensorValuesGen.tvgLazyGenDefault,
4428 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004429 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004430 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004431 "data_gen": {
4432 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004434 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004435 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004436 "gather": {
4437 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004438 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004439 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004440 "build_fcn": (
4441 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004442 TosaTensorGen.tgGather,
4443 TosaTensorValuesGen.tvgGather,
4444 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004445 ),
James Ward24dbc422022-10-19 12:20:31 +01004446 "types": (
4447 DType.INT8,
4448 DType.INT16,
4449 DType.INT32,
4450 DType.FP16,
4451 DType.BF16,
4452 DType.FP32,
4453 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 "error_if_validators": (
4455 TosaErrorValidator.evWrongInputType,
4456 TosaErrorValidator.evWrongOutputType,
4457 TosaErrorValidator.evWrongInputList,
4458 TosaErrorValidator.evWrongOutputList,
4459 TosaErrorValidator.evWrongRank,
4460 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004461 "data_gen": {
4462 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4463 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004464 },
4465 "scatter": {
4466 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004467 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004468 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004469 "build_fcn": (
4470 build_scatter,
4471 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004472 TosaTensorValuesGen.tvgScatter,
4473 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004474 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004475 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004476 "error_if_validators": (
4477 TosaErrorValidator.evWrongInputType,
4478 TosaErrorValidator.evWrongOutputType,
4479 TosaErrorValidator.evWrongInputList,
4480 TosaErrorValidator.evWrongOutputList,
4481 TosaErrorValidator.evWrongRank,
4482 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004483 "data_gen": {
4484 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4485 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004486 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004487 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004488 "resize": {
4489 "op": Op.RESIZE,
4490 "operands": (1, 0),
4491 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004492 "build_fcn": (
4493 build_resize,
4494 TosaTensorGen.tgNHWC,
4495 TosaTensorValuesGen.tvgDefault,
4496 TosaArgGen.agResize,
4497 ),
James Ward24dbc422022-10-19 12:20:31 +01004498 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004499 "invalid_test_validators": (
4500 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004501 ),
4502 "error_if_validators": (
4503 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004504 TosaErrorValidator.evScaleSmallerEqualZero,
4505 TosaErrorValidator.evScaleNLargerMax,
4506 TosaErrorValidator.evScaleDLargerMax,
4507 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004508 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004509 TosaErrorValidator.evBorderSmallerMin,
4510 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004511 TosaErrorValidator.evWrongInputType,
4512 TosaErrorValidator.evWrongOutputType,
4513 TosaErrorValidator.evWrongRank,
4514 TosaErrorValidator.evWrongInputList,
4515 TosaErrorValidator.evWrongOutputList,
4516 TosaErrorValidator.evBatchMismatch,
4517 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004518 TosaErrorValidator.evResizeOutputShapeMismatch,
4519 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004520 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004521 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004522 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004523 "cast": {
4524 "op": Op.CAST,
4525 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004526 "build_fcn": (
4527 build_cast,
4528 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004529 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004530 TosaArgGen.agCast,
4531 ),
James Ward8b390432022-08-12 20:48:56 +01004532 "types": (
4533 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004534 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004535 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004536 DType.INT8,
4537 DType.INT16,
4538 DType.INT32,
4539 DType.BOOL,
4540 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 "error_if_validators": (
4542 TosaErrorValidator.evWrongInputType,
4543 TosaErrorValidator.evWrongOutputType,
4544 TosaErrorValidator.evWrongInputList,
4545 TosaErrorValidator.evWrongOutputList,
4546 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004547 "data_gen": {
4548 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4549 },
4550 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004551 },
4552 "rescale": {
4553 "op": Op.RESCALE,
4554 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004555 "build_fcn": (
4556 build_rescale,
4557 TosaTensorGen.tgBasic,
4558 TosaTensorValuesGen.tvgDefault,
4559 TosaArgGen.agRescale,
4560 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004561 "types": [
4562 DType.UINT8,
4563 DType.INT8,
4564 DType.INT16,
4565 DType.INT32,
4566 DType.INT48,
4567 DType.UINT16,
4568 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004569 "error_if_validators": (
4570 TosaErrorValidator.evInputZeroPointNotZero,
4571 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004572 TosaErrorValidator.evU16InputZeroPointNotValid,
4573 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004574 TosaErrorValidator.evScaleTrue,
4575 TosaErrorValidator.evScaleNotTrue,
4576 TosaErrorValidator.evWrongInputType,
4577 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004578 TosaErrorValidator.evWrongInputList,
4579 TosaErrorValidator.evWrongOutputList,
4580 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004581 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004582 # Custom
4583 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004584 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004585 # Two variants of cond_if, one that generates one of two constant tensors (no
4586 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4587 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004588 "cond_if_const": {
4589 "op": Op.COND_IF,
4590 "operands": (0, 2),
4591 "build_fcn": (
4592 build_cond_if_const,
4593 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004594 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004595 TosaArgGen.agCondIf,
4596 ),
4597 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004598 "error_if_validators": (
4599 TosaErrorValidator.evOutputListThenGraphMismatch,
4600 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004601 TosaErrorValidator.evCondIfCondNotMatchingBool,
4602 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004603 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 },
4605 "cond_if_binary": {
4606 "op": Op.COND_IF,
4607 "operands": (2, 0),
4608 "build_fcn": (
4609 build_cond_if_binary,
4610 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004611 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004612 TosaArgGen.agCondIf,
4613 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004614 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004615 "error_if_validators": (
4616 TosaErrorValidator.evInputListThenGraphMismatch,
4617 TosaErrorValidator.evInputListElseGraphMismatch,
4618 TosaErrorValidator.evOutputListThenGraphMismatch,
4619 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004620 TosaErrorValidator.evCondIfCondNotMatchingBool,
4621 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004622 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004623 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004624 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004625 "while_loop": {
4626 "op": Op.WHILE_LOOP,
4627 "operands": (0, 1),
4628 "build_fcn": (
4629 build_while_loop,
4630 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004631 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004632 TosaArgGen.agWhileLoop,
4633 ),
4634 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004635 "error_if_validators": (
4636 TosaErrorValidator.evInputListOutputListMismatch,
4637 TosaErrorValidator.evInputListCondGraphMismatch,
4638 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4639 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4640 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004641 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004642 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004643 },
Luke Hutton57287132023-02-06 14:54:18 +00004644 "fft2d": {
4645 "op": Op.FFT2D,
4646 "operands": (2, 0),
4647 "rank": (3, 3),
4648 "build_fcn": (
4649 build_fft2d,
4650 TosaTensorGen.tgFFT2d,
4651 TosaTensorValuesGen.tvgDefault,
4652 TosaArgGen.agFFT2d,
4653 ),
4654 "types": [DType.FP32],
4655 "error_if_validators": (
4656 TosaErrorValidator.evWrongInputType,
4657 TosaErrorValidator.evWrongOutputType,
4658 TosaErrorValidator.evWrongInputList,
4659 TosaErrorValidator.evWrongOutputList,
4660 TosaErrorValidator.evWrongRank,
4661 TosaErrorValidator.evBatchMismatch,
4662 TosaErrorValidator.evKernelNotPowerOfTwo,
4663 TosaErrorValidator.evFFTInputShapeMismatch,
4664 TosaErrorValidator.evFFTOutputShapeMismatch,
4665 ),
4666 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004667 "rfft2d": {
4668 "op": Op.RFFT2D,
4669 "operands": (1, 0),
4670 "rank": (3, 3),
4671 "build_fcn": (
4672 build_rfft2d,
4673 TosaTensorGen.tgRFFT2d,
4674 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004675 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004676 ),
4677 "types": [DType.FP32],
4678 "error_if_validators": (
4679 TosaErrorValidator.evWrongInputType,
4680 TosaErrorValidator.evWrongOutputType,
4681 TosaErrorValidator.evWrongInputList,
4682 TosaErrorValidator.evWrongOutputList,
4683 TosaErrorValidator.evWrongRank,
4684 TosaErrorValidator.evBatchMismatch,
4685 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004686 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004687 ),
4688 },
Won Jeon74342e52024-01-09 00:34:40 +00004689 # Shape
4690 "add_shape": {
4691 "op": Op.ADD_SHAPE,
4692 "operands": (2, 0),
4693 "build_fcn": (
4694 build_shape_op,
4695 TosaTensorGen.tgShape,
4696 TosaTensorValuesGen.tvgAddSub,
4697 TosaArgGen.agNone,
4698 ),
4699 "types": [DType.SHAPE],
4700 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4701 },
4702 "sub_shape": {
4703 "op": Op.SUB_SHAPE,
4704 "operands": (2, 0),
4705 "build_fcn": (
4706 build_shape_op,
4707 TosaTensorGen.tgShape,
4708 TosaTensorValuesGen.tvgAddSub,
4709 TosaArgGen.agNone,
4710 ),
4711 "types": [DType.SHAPE],
4712 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4713 },
4714 "mul_shape": {
4715 "op": Op.MUL_SHAPE,
4716 "operands": (2, 0),
4717 "build_fcn": (
4718 build_shape_op,
4719 TosaTensorGen.tgShape,
4720 TosaTensorValuesGen.tvgMul,
4721 TosaArgGen.agNone,
4722 ),
4723 "types": [DType.SHAPE],
4724 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4725 },
4726 "div_shape": {
4727 "op": Op.DIV_SHAPE,
4728 "operands": (2, 0),
4729 "build_fcn": (
4730 build_shape_op,
4731 TosaTensorGen.tgShape,
4732 TosaTensorValuesGen.tvgIntDiv,
4733 TosaArgGen.agNone,
4734 ),
4735 "types": [DType.SHAPE],
4736 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4737 },
4738 "concat_shape": {
4739 "op": Op.CONCAT_SHAPE,
4740 "operands": (2, 0),
4741 "build_fcn": (
4742 build_concat,
4743 TosaTensorGen.tgConcat,
4744 TosaTensorValuesGen.tvgConcat,
4745 TosaArgGen.agNone,
4746 ),
4747 "types": [DType.SHAPE],
4748 "error_if_validators": (),
4749 },
4750 "const_shape": {
4751 "op": Op.CONST_SHAPE,
4752 "operands": (0, 1),
4753 "build_fcn": (
4754 build_const,
4755 TosaTensorGen.tgBasic,
4756 TosaTensorValuesGen.tvgDefault,
4757 None,
4758 ),
4759 "types": [DType.SHAPE],
4760 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004761 }
4762
Kevin Cheng550ccc52021-03-03 11:21:43 -08004763
Eric Kunzee5e26762020-10-13 16:11:07 -07004764class OutputShaper:
4765 # Methods in this class compute the expected output shape and datatype
4766 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; this constructor sets no attributes."""
4769
4770 # These methods return arguments that can be used for
4771 # creating a new output tensor
4772 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise binary op with broadcasting.

    On each axis the output takes ``b``'s extent where ``a`` has extent 1
    (broadcast), otherwise ``a``'s extent.  Broadcasting is only applied when
    no ERROR_IF case is requested.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: first input tensor (``.shape`` and ``.dtype`` are read).
        b: second input tensor; must share ``a``'s dtype.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast_ok = error_name is None
    out_shape = [
        b.shape[axis] if (extent == 1 and broadcast_ok) else extent
        for axis, extent in enumerate(a.shape)
    ]

    # The fuzz axis is drawn unconditionally, even when no
    # DimensionMismatch is being injected.
    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004805
4806 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose operands must match exactly.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        a: first input tensor.
        b: second input tensor; rank, every dimension and dtype must equal ``a``'s.

    Returns:
        The value returned by ``ser.addOutput`` for a tensor shaped like ``a``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for axis, extent in enumerate(a.shape):
        assert extent == b.shape[axis]

    # A fresh list so the caller's shape is never aliased.
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004817
4818 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output always has ``a``'s shape.  Its dtype is ``a.dtype`` unless a
    WrongOutputType case is requested, in which case a different dtype is
    drawn at random.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004836
4837 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for SELECT (cond ? a : b) with broadcasting.

    The output shape follows ``cond``.  Where ``cond`` has extent 1 and no
    error is being injected, the largest extent among ``cond``, ``a`` and
    ``b`` is used instead.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        cond: condition tensor.
        a: "true" value tensor.
        b: "false" value tensor; must share ``a``'s dtype.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for axis, cond_extent in enumerate(cond.shape):
        if error_name is None and cond_extent == 1:
            out_shape.append(max(cond_extent, a.shape[axis], b.shape[axis]))
        else:
            out_shape.append(cond_extent)

    # Drawn unconditionally, even when no DimensionMismatch is injected.
    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004870
4871 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor for an element-wise comparison op.

    Broadcasting takes ``b``'s extent on axes where ``a`` has extent 1 and
    ``b`` has that axis.  A WrongOutputType case replaces BOOL with a random
    non-BOOL dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: first input tensor.
        b: second input tensor; must share ``a``'s dtype.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the length check tolerates a shorter b under RankMismatch.
    out_shape = [
        b.shape[axis] if (extent == 1 and len(b.shape) > axis) else extent
        for axis, extent in enumerate(a.shape)
    ]

    fuzz_axis = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_axis] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.BOOL
    else:
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004904
4905 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a REDUCE_* op along ``axis``.

    Normally the reduced axis is collapsed to extent 1.  For the axis-related
    error cases the input extent is kept, and ShapeOfAxisNotOne forces an
    extent of at least 2.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        axis: axis being reduced.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    keeps_axis = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keeps_axis:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004933
4934 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the INT32 output tensor for ARGMAX along ``axis``.

    The reduced axis is removed from the shape unless the axis itself is
    invalid.  The rank- and shape-mismatch error cases perturb the result,
    and WrongOutputType picks a non-INT32 dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        axis: axis being reduced.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.INT32
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004967
4968 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        ifm: input feature-map tensor.
        filter: weight tensor (OHWI).
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: per spatial axis (H, W).
        padding: padding[0:2] pads H, padding[2:4] pads W.
        dilations: per spatial axis (H, W).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent of a dilated convolution along one spatial axis.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    # Spatial axes H, W sit at IFM/filter index 1 and 2.
    spatial = [
        _out_size(
            ifm.shape[axis + 1],
            padding[2 * axis],
            padding[2 * axis + 1],
            filter.shape[axis + 1],
            dilations[axis],
            strides[axis],
        )
        for axis in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both.
        which = rng.choice(steps)
        for axis in range(2):
            if which in (axis + 1, 3):
                # Grow by a multiple of the stride to avoid the non-integer error case.
                spatial[axis] = spatial[axis] + rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With an incorrect input type, use a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the "wrong" candidates.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005019
5020 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        ifm: input feature-map tensor.
        filter: weight tensor (ODHWI).
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: per spatial axis (D, H, W).
        padding: padding[0:2] pads D, padding[2:4] pads H, padding[4:6] pads W.
        dilations: per spatial axis (D, H, W).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent of a dilated convolution along one spatial axis.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    # Spatial axes D, H, W sit at IFM/filter index 1, 2 and 3.
    spatial = [
        _out_size(
            ifm.shape[axis + 1],
            padding[2 * axis],
            padding[2 * axis + 1],
            filter.shape[axis + 1],
            dilations[axis],
            strides[axis],
        )
        for axis in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1 -> grow D, 2 -> grow H, 3 -> grow W, 4 -> grow all three.
        which = rng.choice(steps)
        for axis in range(3):
            if which in (axis + 1, 4):
                # Grow by a multiple of the stride to avoid the non-integer error case.
                spatial[axis] = spatial[axis] + rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With an incorrect input type, use a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the "wrong" candidates.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5081
5082 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        ifm: input feature-map tensor.
        filter: weight tensor (HWCM).
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: per spatial axis (H, W).
        padding: padding[0:2] pads H, padding[2:4] pads W.
        dilations: per spatial axis (H, W).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent of a dilated convolution along one spatial axis.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    # Spatial axes H, W are IFM index 1-2 but filter index 0-1 (HWCM layout).
    spatial = [
        _out_size(
            ifm.shape[axis + 1],
            padding[2 * axis],
            padding[2 * axis + 1],
            filter.shape[axis],
            dilations[axis],
            strides[axis],
        )
        for axis in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both.
        which = rng.choice(steps)
        for axis in range(2):
            if which in (axis + 1, 3):
                # Grow by a multiple of the stride to avoid the non-integer error case.
                spatial[axis] = spatial[axis] + rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    # With an incorrect input type, use a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the "wrong" candidates.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005132
5133 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op on an NHWC input.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        ifm: input tensor (NHWC).
        kernel: kernel size (H, W).
        stride: stride (H, W).
        pad: pad[0:2] pads H, pad[2:4] pads W.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # An invalid stride/pad makes the test invalid anyway; use 1x1.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both.
        which = rng.choice(steps)
        # Grow by multiples of the stride to avoid the non-integer error case.
        if which in (1, 3):
            h = h + rng.choice(steps) * stride[0]
        if which in (2, 3):
            w = w + rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005170
5171 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    Shapes: input (N, IC), filter (OC, IC) -> output (N, OC).

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: unused here.
        input: input tensor.
        filter: weight tensor.
        accum_dtype: output dtype.
        error_name: unused here.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated (and invalidated for ERROR_IF tests) in arg_gen.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005183
5184 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for batched MATMUL.

    Shapes: a (N, H, C), b (N, C, W) -> output (N, H, W).

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: left-hand input tensor.
        b: right-hand input tensor.
        accum_dtype: accumulator dtype (validated in arg_gen), used as the
            normal output dtype.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # Every candidate except INT32.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # Every candidate except INT48.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: only integer types are wrong.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype would otherwise leave incorrect_types
            # unbound (UnboundLocalError).  Fall back to every candidate
            # except the requested accumulator type.
            incorrect_types = tuple(
                dtype
                for dtype in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005231
5232 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT of ``inputs`` along ``axis``.

    The output extent along ``axis`` is the sum over all inputs, except for
    the rank/axis error cases, where the first input's shape is used as-is.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        axis: concatenation axis.
        inputs: sequence of input tensors (at least one).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    # These cases have mismatched ranks or an invalid axis, so no sum is possible.
    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
5268 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each axis grows by ``padding[axis][0] + padding[axis][1]``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        padding: per-axis (before, after) padding pairs.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    for axis, extent in enumerate(out_shape):
        pad_before, pad_after = padding[axis][0], padding[axis][1]
        out_shape[axis] = pad_before + pad_after + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        shrink_axis = rng.choice(range(len(out_shape)))
        out_shape[shrink_axis] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005298
5299 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM: always shape [1], normally DType.SHAPE.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor (unused here).
        axis: queried axis (unused here).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.SHAPE
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        # The set() round-trip excludes nothing (SHAPE is not a candidate); it
        # is kept because it fixes the order rng.choice draws from.
        out_dtype = rng.choice(list(set(candidates)))

    return ser.addOutput([1], out_dtype)
5319
5320 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE to ``shape``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        shape: requested output shape (copied, not modified).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension so the element count no longer matches.
        for axis in range(len(out_shape)):
            out_shape[axis] += rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005344
5345 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE; the output shape is ``size``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        input: input tensor.
        start: slice start (unused here).
        size: slice size, used as the output shape.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # The dtype is drawn before any shape perturbation below.
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for axis, extent in enumerate(out_shape):
            # Small extents only grow, so they never become non-positive.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            out_shape[axis] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005378
5379 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE: each axis is scaled by ``multiples``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        multiples: per-axis repeat counts; same length as ``a.shape``.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for axis, extent in enumerate(a.shape):
        out_shape[axis] = extent * multiples[axis]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005407
5408 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE: axis i takes ``a.shape[perms[i]]``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used when injecting an error case.
        a: input tensor.
        perms: permutation; same length as ``a.shape``.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    # Invalid permutations cannot be applied, so the input shape is kept.
    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        out_shape[:] = [a.shape[perms[axis]] for axis in range(len(out_shape))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for axis in range(len(out_shape)):
            out_shape[axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005440
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` and the output
    dtype follows ``values`` unless WrongOutputType is being injected.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used for WrongOutputType.
        values: tensor expected to be rank 3 (checked unless WrongRank).
        indices: tensor expected to be rank 2 (checked unless WrongRank).
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        # values is rank 3, indices is rank 2, and they agree on dim 0
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any listed dtype other than the values' dtype is invalid here
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005466
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    The output reuses ``values_in.shape`` (the same shape object is passed
    to the serializer) and, unless WrongOutputType is injected, its dtype.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used for WrongOutputType.
        values_in: tensor expected to be rank 3.
        indices: tensor expected to be rank 2.
        input: tensor expected to be rank 3.
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any listed dtype other than values_in's dtype is invalid here
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005495
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The output has the input's shape.  An INT16 input yields an INT32
    output; any other input dtype yields INT8.  With WrongOutputType a
    different dtype is drawn at random instead.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used for WrongOutputType.
        input: input tensor; INT8 or INT16 unless WrongInputType is injected.
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Keep list order; drop only the dtype that would have been correct
        out_dtype = rng.choice([dt for dt in all_dtypes if dt != out_dtype])

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005515
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE.

    Per spatial axis the output size is
    ``((in - 1) * scale_n - offset + border) // scale_d + 1`` with
    ``scale = [y_n, y_d, x_n, x_d]``, ``offset = [y, x]`` and
    ``border = [y, x]``.  Batch and channel are taken from ``input``.
    ``mode`` and ``input_dtype`` are accepted but not used by this shaper.

    Args:
        serializer: ``serializer.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used by the ERROR_IF variants.
        input: input tensor; ``input.shape[1]``/``[2]`` are the spatial dims.
        output_dtype: dtype given to the created tensor.
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp local copies only, so the divisions below stay well defined;
        # the caller's (deliberately invalid) scale is left untouched.
        y_num, y_den, x_num, x_den = (
            max(v, 1) for v in (y_num, y_den, x_num, x_den)
        )

    in_h, in_w = input.shape[1], input.shape[2]
    oh = ((in_h - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((in_w - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # ERROR_IF changes to scale/offset/border can push the size out of
        # range: keep it at least 1 and, unless the limit itself is under
        # test, below MAX_RESIZE_DIMENSION.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(size, step):
            # Move by a whole denominator step so the non-integer ERROR_IF
            # case is not triggered as well; shrink if growing would exceed
            # the maximum dimension.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        # 1 -> height only, 2 -> width only, 3 -> both
        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            oh = _shift(oh, y_den)
        if change in (2, 3):
            ow = _shift(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): dim 0 is reused as the last output dim here,
        # presumably because a wrong-rank input need not have index 3.
        channels = input.shape[0]
    else:
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        channels = input.shape[3]
        if error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)

    return serializer.addOutput([batch, oh, ow, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005594
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type conversion.

    The output keeps ``val``'s shape and takes ``out_dtype``.  ``rng`` and
    ``error_name`` are accepted but not used by this shaper.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    shape_out = val.shape
    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005598
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only used by the ERROR_IF variants.
        ifm: input tensor (only its ``dtype`` is read).
        output_shape: requested output shape.  NOTE: for
            ConvOutputShapeMismatch this object is modified *in place*
            (dims 1 and/or 2 grow), so the caller observes the change.
        accum_dtype: dtype used for the output in the normal case.
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 -> grow dim 1, 2 -> grow dim 2, 3 -> grow both (dim 1 first)
        for axis, selectors in ((1, (1, 3)), (2, (2, 3))):
            if change in selectors:
                output_shape[axis] = output_shape[axis] + rng.choice(choices)

    # With a deliberately wrong input dtype, fall back to a plausible output
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # NOTE(review): FP16 inputs exclude both FP16 and FP32, presumably
        # because either is an acceptable output for them; confirm.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005624
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D.

    Both outputs copy ``ifm1``'s shape and dtype, unless an ERROR_IF
    variant changes the dtype, the batch dim, or one spatial dim.

    Args:
        serializer: ``serializer.addOutput(shape, dtype)`` creates a tensor.
        rng: random generator, only used by the ERROR_IF variants.
        ifm1, ifm2: input tensors; same dtype, and same shape unless
            FFTInputShapeMismatch is injected.
        error_name: optional ErrorIf value to inject.

    Returns:
        A list of two ``serializer.addOutput`` results created with the same
        shape and dtype (presumably the real and imaginary parts; confirm).
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5655
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D.

    The outputs keep every input dim except the last, which becomes
    ``last // 2 + 1``; the dtype follows ``value`` unless WrongOutputType is
    injected.

    Args:
        serializer: ``serializer.addOutput(shape, dtype)`` creates a tensor.
        rng: random generator, only used by the ERROR_IF variants.
        value: input tensor, rank 3 unless WrongRank is injected.
        error_name: optional ErrorIf value to inject.

    Returns:
        A list of two ``serializer.addOutput`` results with identical shape
        and dtype.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Only the last axis changes size: W -> W // 2 + 1
    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005680
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for ADD_SHAPE.

    The output copies ``a``'s shape and has dtype ``DType.SHAPE``, unless
    DimensionMismatch (one dim bumped by 1) or WrongOutputType is injected.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (see the note on the unconditional draw).
        a, b: operand tensors; same rank and dtype unless RankMismatch.
        error_name: optional ErrorIf value to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

    out_shape = list(a.shape)

    # NOTE(review): this index is drawn on every call, not only for
    # DimensionMismatch, so one rng value is always consumed (and a rank-0
    # ``a`` presumably makes the draw fail).  Kept to preserve the stream.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(out_shape, out_dtype)