blob: b9352acbb9795a7e2fec0221d5604a0914a30c1a [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file.
# The copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
    def __init__(self, args):
        """Set up generator state from the parsed command line arguments.

        Reads args.output_dir, args.random_seed, args.generate_lib_path and
        args.tensor_fp_value_range here; other methods read further fields
        of self.args later.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Set by createSerializer()
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): these two helpers run before quantGen, targetted_shape,
        # descSchemaValidator, dgl and random_float_range are assigned below;
        # keep this ordering unless the helpers are checked for dependencies.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
        # JSON schema validation
        self.descSchemaValidator = TestDescSchemaValidator()
        # Data generator library is sometimes needed for compliance set up
        # even if we are generating the data later (lazy_data_generation)
        self.dgl = GenerateLibrary(args.generate_lib_path)

        # Work out floating point range
        def convertFPRange(rangeFP, maxFP):
            # Converts program arguments of max/-max to FP max, clamps any
            # other value into [-maxFP, maxFP] and returns them sorted
            vals = []
            for v in rangeFP:
                if v == "max":
                    v = maxFP
                elif v == "-max":
                    v = -maxFP
                elif v < 0:
                    # Trim to minimum data type value
                    v = max(v, -maxFP)
                elif v > 0:
                    # Trim to maximum data type value
                    v = min(v, maxFP)
                vals.append(v)
            return tuple(sorted(vals))

        # Per floating point dtype (low, high) range used by getDTypeRange
        self.random_float_range = {}
        for dtype in (DType.FP32, DType.FP16, DType.BF16):
            self.random_float_range[dtype] = convertFPRange(
                args.tensor_fp_value_range,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
            )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000194 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100195 return np.int64(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
197 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
198
199 if dtype == DType.FP16:
200 return np.float16(f_tensor)
201 else:
202 f32_tensor = np.float32(f_tensor)
203 if dtype == DType.BF16:
204 # Floor the last 16 bits of each f32 value
205 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
206 else:
207 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100209 # All other integer types
210 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700211
Kevin Cheng989cb052021-04-28 16:29:44 -0700212 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 placeholders = []
214
Kevin Cheng989cb052021-04-28 16:29:44 -0700215 assert len(shape_list) == len(dtype_list)
216
Jeremy Johnson1271c442023-09-05 11:39:26 +0100217 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100219 if not self.args.lazy_data_gen:
220 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700222
223 return placeholders
224
Kevin Cheng989cb052021-04-28 16:29:44 -0700225 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700226 consts = []
227
Kevin Cheng989cb052021-04-28 16:29:44 -0700228 assert len(shape_list) == len(dtype_list)
229
Jeremy Johnson1271c442023-09-05 11:39:26 +0100230 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700231 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100232 if not self.args.lazy_data_gen:
233 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700234 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700235
236 return consts
237
238 def makeShape(self, rank):
239 if self.targetted_shape:
240 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800241 return np.int32(
242 self.rng.integers(
243 low=self.args.tensor_shape_range[0],
244 high=self.args.tensor_shape_range[1],
245 size=rank,
246 )
247 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700248
249 def setTargetShape(self, shape):
250 self.targetted_shape = shape
251
252 def randInt(self, low=0, high=256):
253 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
254
255 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 low, high = self.getDTypeRange(dtype)
257
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100258 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100259 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100260 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100261 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100262 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
264 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 elif dtype == DType.BOOL:
266 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000267 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700268 # Special size
269 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700270
271 return np.int32(self.rng.integers(low, high, size=1))[0]
272
273 def shapeStr(self, shape):
274
275 sStr = []
276 # Convert to strings
277 for i in shape:
278 sStr.append(str(i))
279
Kevin Cheng550ccc52021-03-03 11:21:43 -0800280 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700281
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100282 def typeStr(self, dtype):
283 if isinstance(dtype, list) or isinstance(dtype, tuple):
284 assert len(dtype) >= 2
285 strs = [self.typeStr(t) for t in dtype]
286 # Limit types to the first 2 as the 3rd is the accumulator
287 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700288 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100289 if dtype in gtu.DTYPE_ATTRIBUTES:
290 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700291 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100292 raise Exception(
293 "Unknown dtype, cannot convert to string: {}".format(dtype)
294 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100296 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100297 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100298 if dtype in gtu.DTYPE_ATTRIBUTES:
299 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100301 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700302
Luke Hutton57287132023-02-06 14:54:18 +0000303 def constrictBatchSize(self, shape):
304 # Limit the batch size unless an explicit target shape set
305 if self.args.max_batch_size and not self.args.target_shapes:
306 shape[0] = min(shape[0], self.args.max_batch_size)
307 return shape
308
James Ward30124a82023-02-02 14:56:33 +0000309 def makeDimension(self):
310 return self.randInt(
311 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
312 )
313
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100314 def tensorComplianceMetaData(
315 self, op, inputType, argsDict, outputTensor, errorName
316 ):
Jeremy Johnson708da822023-11-15 16:25:45 +0000317 # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
318 UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100319 if (
320 errorName
321 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000322 or (
323 not gtu.dtypeIsSupportedByCompliance(inputType)
324 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
325 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100326 ):
327 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100328 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100329
Jeremy Johnson1271c442023-09-05 11:39:26 +0100330 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100331 compliance_tens = {
332 "mode": None,
333 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
334 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
335 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100336 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
337 mode = gtu.ComplianceMode.DOT_PRODUCT
338 compliance_tens["dot_product_info"] = {
339 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100340 "ks": int(argsDict["ksb"])
341 if "ksb" in argsDict
342 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100343 }
344 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
345 mode = gtu.ComplianceMode.FP_SPECIAL
346 elif "compliance" in op and "ulp" in op["compliance"]:
347 mode = gtu.ComplianceMode.ULP
348 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
349 elif op["op"] == Op.REDUCE_PRODUCT:
350 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000351 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000352 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000353 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000354 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
355 compliance_tens["abs_error_info"] = {
356 "lower_bound": op["compliance"]["abs_error_lower_bound"]
357 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100358 else:
359 mode = gtu.ComplianceMode.EXACT
360 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
361
362 return compliance_tens
363
364 # Build Op functions
365 # Create the output tensor (calling OutputShaper as needed)
366 # Do final tweaks to attributes (if necessary for errorIf)
367 # Add Op into graph
368 # Return resulting tensor information or BuildInfo
369
370 class BuildInfo:
371 """Enhanced build information containing result tensor and associated compliance dict."""
372
373 def __init__(self, resultTensor, complianceDict):
374 self.resultTensor = resultTensor
375 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700376
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000377 def build_unary(
378 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
379 ):
380 assert len(inputs) == 1
381 a = inputs[0]
382 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100383
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000384 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100385
386 # Ensure new output type has correct qinfo
387 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000388 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000389 qinfo = [
390 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000391 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000392 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100393
394 # Invalidate Input/Output list for error if checks.
395 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000396 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100397 pCount, cCount = op["operands"]
398 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000399 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
400 self, error_name, input_list, output_list
401 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100402
Les Bell729b0352021-11-24 10:28:21 +0000403 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100404 self.ser,
405 validator_fcns,
406 error_name,
407 op=op,
408 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000409 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000410 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000411 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100412 input_list=input_list,
413 output_list=output_list,
414 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000415 ):
416 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100417
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000418 attr = None
419 if op["op"] == Op.NEGATE:
420 attr = ts.TosaSerializerAttribute()
421 attr.NegateAttribute(qinfo[0], qinfo[1])
422
423 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000424
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000425 compliance = self.tensorComplianceMetaData(
426 op, a.dtype, args_dict, result_tensor, error_name
427 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000428 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700429
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000430 def build_binary_broadcast(
431 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
432 ):
433 assert len(inputs) == 2
434 a, b = inputs
435 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 self.ser, self.rng, a, b, error_name
437 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100438
439 # Invalidate Input/Output list for error if checks.
440 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000441 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100442 pCount, cCount = op["operands"]
443 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000444 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
445 self, error_name, input_list, output_list
446 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100447
Les Bell729b0352021-11-24 10:28:21 +0000448 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100449 self.ser,
450 validator_fcns,
451 error_name,
452 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000453 input1=a,
454 input2=b,
455 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000456 output_dtype=result_tensor.dtype,
457 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100458 input_list=input_list,
459 output_list=output_list,
460 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000461 ):
462 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100463
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000464 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000465
Jeremy Johnson9a758382023-11-07 16:27:35 +0000466 compliance = self.tensorComplianceMetaData(
467 op, a.dtype, args_dict, result_tensor, error_name
468 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000469
470 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700471
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100472 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700473 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700475 return result_tens
476
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000477 def build_arithmetic_right_shift(
478 self, op, a, b, round, validator_fcns=None, error_name=None
479 ):
480 result_tens = OutputShaper.binaryBroadcastOp(
481 self.ser, self.rng, a, b, error_name
482 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100483
484 # Invalidate Input/Output list for error if checks.
485 input_list = [a.name, b.name]
486 output_list = [result_tens.name]
487 pCount, cCount = op["operands"]
488 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000489 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
490 self, error_name, input_list, output_list
491 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100492
Les Bell729b0352021-11-24 10:28:21 +0000493 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100494 self.ser,
495 validator_fcns,
496 error_name,
497 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000498 input1=a,
499 input2=b,
500 input_dtype=a.dtype,
501 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000502 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100503 input_list=input_list,
504 output_list=output_list,
505 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000506 ):
507 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800508
509 attr = ts.TosaSerializerAttribute()
510 attr.ArithmeticRightShiftAttribute(round)
511
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000512 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800513 return result_tens
514
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100515 def build_mul(
516 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
517 ):
518 assert len(inputs) == 2
519 a, b = inputs
520 shift = args_dict["shift"]
521
522 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000523 self.ser, self.rng, a, b, error_name
524 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700525
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100526 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100527 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100528 result_tensor.setDtype(DType.INT32)
529
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100530 if error_name == ErrorIf.WrongOutputType:
531 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
532 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100533 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100534
535 # Invalidate Input/Output list for error if checks.
536 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100537 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100538 pCount, cCount = op["operands"]
539 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000540 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
541 self, error_name, input_list, output_list
542 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100543
Les Bell729b0352021-11-24 10:28:21 +0000544 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100545 self.ser,
546 validator_fcns,
547 error_name,
548 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000549 input1=a,
550 input2=b,
551 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100552 output_dtype=result_tensor.dtype,
553 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 input_list=input_list,
555 output_list=output_list,
556 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000557 ):
558 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700559
Kevin Chengaee1fac2020-11-11 13:54:06 -0800560 attr = ts.TosaSerializerAttribute()
561 attr.MulAttribute(shift)
562
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000563 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100564
565 compliance = self.tensorComplianceMetaData(
566 op, a.dtype, args_dict, result_tensor, error_name
567 )
568
569 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700570
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100571 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
572 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700573
Kevin Chengfe392ce2021-10-18 21:51:55 +0000574 attr = ts.TosaSerializerAttribute()
575 attr.TableAttribute(table)
576
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100577 # Invalidate Input/Output list for error if checks.
578 input_list = [a.name]
579 output_list = [result_tens.name]
580 pCount, cCount = op["operands"]
581 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000582 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
583 self, error_name, input_list, output_list
584 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100585
Les Bell729b0352021-11-24 10:28:21 +0000586 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100587 self.ser,
588 validator_fcns,
589 error_name,
590 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000591 input_shape=a.shape,
592 input_dtype=a.dtype,
593 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000594 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100595 input_list=input_list,
596 output_list=output_list,
597 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000598 ):
599 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100600
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000601 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700602
603 return result_tens
604
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000605 def build_select(
606 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
607 ):
608 assert len(inputs) == 3
609 cond, a, b = inputs
610
611 result_tensor = OutputShaper.selectOp(
612 self.ser, self.rng, cond, a, b, error_name
613 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100614
615 # Invalidate Input/Output list for error if checks.
616 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000617 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100618 pCount, cCount = op["operands"]
619 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000620 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
621 self, error_name, input_list, output_list
622 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100623
Les Bell729b0352021-11-24 10:28:21 +0000624 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100625 self.ser,
626 validator_fcns,
627 error_name,
628 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000629 input1=cond,
630 input2=a,
631 input3=b,
632 input_shape=a.shape,
633 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000634 output_dtype=result_tensor.dtype,
635 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100636 input_list=input_list,
637 output_list=output_list,
638 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000639 ):
640 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100641
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000642 self.ser.addOperator(
643 op["op"],
644 input_list,
645 output_list,
646 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000647 compliance = self.tensorComplianceMetaData(
648 op, a.dtype, args_dict, result_tensor, error_name
649 )
650
651 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700652
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator from two input tensors.

    Returns:
        None when the ERROR_IF validators reject the generated test,
        otherwise a TosaTestGen.BuildInfo holding the result tensor and its
        compliance meta data. ``qinfo`` is not used by this builder.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may invalidate the input/output name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance_info = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -0700700
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an argmax operator that reduces one input tensor along an axis.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700744
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator from a single input tensor.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "stride", "pad" and "kernel" are read, and
            "acc_type" when present (max_pool has no accumulator type).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to PoolAttribute, or None
            (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    # Named ifm rather than "input" to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100818
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator from input, weight and bias tensors.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence of exactly three tensors: input, weights, bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" (4 values) and
            "dilation" are read, and the dict is forwarded to the compliance
            meta data helper.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to ConvAttribute, or None
            (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700900
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator from input, weight and bias tensors.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence of exactly three tensors: input, weights, bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" (6 values) and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to ConvAttribute, or None
            (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise the result
        tensor itself (this builder produces no compliance meta data).
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
977
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        ifm: Input tensor.
        filter: Weight tensor (the parameter name is kept for callers even
            though it shadows the builtin).
        bias: Bias tensor.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride values.
        out_pad: Output padding; exactly 4 values.
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to TransposeConvAttribute,
            or None (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise the result
        tensor itself (this builder produces no compliance meta data).
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1045
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator from input, weight and bias tensors.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence of exactly three tensors: input, weights, bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to ConvAttribute, or None
            (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise the result
        tensor itself (this builder produces no compliance meta data).
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
1121
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator from input, weight and bias tensors.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence of exactly three tensors: input, weights, bias.
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to FullyConnectedAttribute,
            or None (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" to avoid shadowing the builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001177
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator from two input tensors.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence of exactly two input tensors (a, b).
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance meta data helper.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Two-element zero point list passed to MatMulAttribute, or
            None (treated as [0, 0]).

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # qinfo is optional: default both zero points to 0 instead of
        # raising TypeError when it is subscripted below.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001227
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reduction operator over one axis of a single input tensor.

    Args:
        op: Operator definition dict; the "op" and "operands" keys are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read. For a valid (non-error)
            REDUCE_PRODUCT test this dict is MUTATED: "n" is set to the size
            of the reduced axis, as the compliance meta data needs it.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly generated bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes min_val and the larger one max_val. For ErrorIf.MaxSmallerMin the
    pair is redrawn until the values differ, and then assigned the other way
    round so that max_val < min_val.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    a = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _random_pair():
        # Two independent random draws in the input tensor's dtype
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = _random_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # The values must differ for max < min to be a genuine error case
        while bounds[0] == bounds[1]:
            bounds = _random_pair()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    placeholder_count, const_count = op["operands"]
    # ERROR_IF tests may invalidate the input/output name lists
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes: bounds go in the first two ClampAttribute value slots
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001342
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on tensor ``a``.

    Legacy-style builder: takes the input tensor directly, runs no ERROR_IF
    validation, and passes a random FP32 value to LeakyReluAttribute.
    Returns the result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    leaky_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1351
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator on tensor ``a`` with no attribute.

    Legacy-style builder: runs no ERROR_IF validation and returns the result
    tensor. Only a single input is wired up so far.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1358
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that has no attribute.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode given to the
            serializer; ``op["operands"]`` holds two counts whose sum is passed
            to the ERROR_IF validators as ``num_operands``.
        inputs: Sequence containing exactly one input tensor.
        args_dict: Test arguments forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validator functions.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    (a,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001399
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator.

    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it in an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and has no attribute.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode;
            ``op["operands"]`` holds two counts summed into ``num_operands``.
        inputs: Input tensors to concatenate. The first one supplies the
            shape/dtype given to the ERROR_IF validators and to compliance.
        args_dict: Test arguments (``"axis"`` is read for CONCAT).
        validator_fcns: ERROR_IF validator functions.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance() instead of comparing type() directly (PEP 8).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001458
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a pad operator from a single input tensor.

    ``args_dict["pad"]`` is used for the output shape and, flattened, for the
    PadAttribute. ``args_dict["pad_const_int"]`` and ``args_dict["pad_const_fp"]``
    supply the attribute's constants. The attribute is built before ERROR_IF
    validation runs.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    (a,) = inputs
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001516
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a dim operator for one input along ``args_dict["axis"]``.

    The axis is serialized in an AxisAttribute. No compliance metadata is
    produced.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when ERROR_IF
        validation fails.
    """
    assert len(inputs) == 1
    (a,) = inputs
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001562
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reshape operator from an input tensor and a shape tensor.

    ``args_dict["new_shape"]`` determines the output shape. Both input tensor
    names are listed as operator inputs, and no attribute is attached.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    a, shape = inputs
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, new_shape, error_name
    )

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, shape.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001606
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reverse operator for one input along ``args_dict["axis"]``.

    The axis is serialized in an AxisAttribute. No compliance metadata is
    produced.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when ERROR_IF
        validation fails.
    """
    assert len(inputs) == 1
    (a,) = inputs
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001646
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a transpose operator for ``a`` using permutation ``perms``.

    The TransposeAttribute is built before ERROR_IF validation runs.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return out_tensor
1682
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a slice operator for ``a`` from ``start`` with extent ``size``.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1721
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a tile operator from an input tensor and a multiples tensor.

    ``args_dict["multiples"]`` determines the output shape. Both input tensor
    names are listed as operator inputs, and no attribute is attached.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    a, multiples = inputs
    multiples_values = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, a, multiples_values, error_name
    )

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, multiples.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001766
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a gather operator from a values tensor and an indices tensor.

    The values tensor supplies the input shape/dtype for ERROR_IF validation
    and the dtype for compliance metadata.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001809
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a scatter operator from values_in, indices and input tensors.

    The ``values_in`` tensor supplies the input shape/dtype for ERROR_IF
    validation and the dtype for compliance metadata.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    # Unpacked as ``input_tensor`` to avoid shadowing the ``input`` builtin.
    values_in, indices, input_tensor = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [out_tensor.name],
    )

    validation_kwargs = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001851
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a resize operator for ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` go to both the output
    shaper and the ResizeAttribute. ``input_dtype`` and ``output_dtype`` are
    passed explicitly to the shaper and the ERROR_IF validators.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1913
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output identity operator.

    Args:
        op: Operator description. Every other builder in this class passes
            ``op["op"]`` to ``addOperator``. This one previously passed the
            whole ``op`` through, so the opcode is now extracted when ``op`` is
            a dict. A bare opcode is still accepted for backward compatibility.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Unused; no ERROR_IF validation is performed.
        error_name: Forwarded to ``OutputShaper.unaryOp``.

    Returns:
        The first output tensor. The second output is only registered on the
        operator.
    """
    opcode = op["op"] if isinstance(op, dict) else op

    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1921
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Pass ``val`` to the serializer's ``addOutputTensor`` and return it as-is.

    ``op``, ``validator_fcns`` and ``error_name`` are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001925
1926 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a cast operator converting one input to ``args_dict["out_type"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata (keyed on the input dtype), or None when ERROR_IF validation
        fails.
    """
    assert len(inputs) == 1
    (val,) = inputs
    target_dtype = args_dict["out_type"]

    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, target_dtype, error_name
    )

    first_count, second_count = op["operands"]

    # ERROR_IF generation may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=first_count + second_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001970
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator test converting ``val`` to ``out_dtype``.

    Chooses input/output zero points suited to the dtypes (or deliberately
    invalid ones for the zero-point ERROR_IF tests), derives a multiplier and
    shift per channel from a random scale, and - for positive scale32 tests -
    clamps the saved input data so the apply_scale_32 value-range requirement
    holds.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension), or a single pair
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # NOTE(review): each width is bumped by one whenever the zero point may be
    # non-zero - presumably to allow for the zero-point offset; confirm.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^output_width) / (2^input_width),  a random in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Allowed (zero-point adjusted) input range for this channel's shift
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2134
def _get_condition_tensor(self, op, cond, error_name):
    """Add a constant condition tensor holding ``cond`` for a COND_IF test.

    The constant is a rank-0 BOOL tensor, except for two ERROR_IF tests:
    CondIfCondNotMatchingBool uses a wrong (non-BOOL) dtype, and
    CondIfCondShapeNotSizeOne uses a shape of [2] or [1, 2].

    Returns:
        The constant tensor added to the serializer.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2151
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each return an INT32 constant.

    Of the supplied tensors only ``then_tens.shape`` is used (as the output
    shape), together with the ``cond`` value. For the output-list mismatch
    ERROR_IF tests, the affected block returns a constant whose every
    dimension has been perturbed.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the graph.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if then_mismatch or else_mismatch:
        # Shift each dimension so the block output cannot match out_shape
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] += self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] += self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # Operator result, shaped like the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block contains just a constant that becomes its output
    for block_name, use_bad_shape, good_arr in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if use_bad_shape:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, good_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2220
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD (then) and SUB (else); INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else). For
    the input/output list mismatch ERROR_IF tests, the affected block is given
    an input or output whose shape has every dimension perturbed.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the graph.

    Raises:
        AssertionError: if ``a.dtype`` has no block operators defined.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Copy of ``a`` with a perturbed shape, used to make one block's
        # input or output list disagree with the operator's
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be named ``op``: that would shadow the
    # operator description passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        input_mismatch = (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        )
        output_mismatch = (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        )
        self.ser.addInputTensor(incorrect_block_input if input_mismatch else a)
        self.ser.addInputTensor(b)
        block_out_shape = incorrect_block_input.shape if output_mismatch else a.shape
        tens = self.ser.addOutput(block_out_shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2303
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that accumulates ``a`` ``iter_val`` times.

    The loop carries (counter, a, accumulator). The COND block evaluates
    ``counter > 0``; the BODY block computes ``a + acc`` and ``counter - 1``.
    For the list-mismatch ERROR_IF tests the relevant block inputs/outputs get
    perturbed shapes, and for the condition ERROR_IF tests the COND output gets
    a non-BOOL dtype or a shape that is not size one.

    Returns:
        The accumulator output tensor, or None if the ERROR_IF validators
        reject the graph.
    """
    # Loop counter (named iter_tens so the builtin ``iter`` is not shadowed)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Accumulator output shaped differently from the accumulator input
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # Rank-0 counter: add a dimension so the shape cannot match
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2424
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator test from the two input tensors val1 and val2.

    ``inverse`` is forwarded to both the ERROR_IF validators and the FFT
    attribute. The attribute's local_bound flag is currently always False.

    Returns:
        The list of result tensors from ``OutputShaper.fft2dOp``, or None if
        the ERROR_IF validators reject the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    in_names = [val1.name, val2.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    out_names = [tens.name for tens in results]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    # Invalidate the input/output name lists for ERROR_IF tests
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2475
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator test on the single input tensor ``val``.

    The attribute's local_bound flag is currently always False.

    Returns:
        The list of result tensors from ``OutputShaper.rfft2dOp``, or None if
        the ERROR_IF validators reject the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    in_names = [val.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    out_names = [tens.name for tens in results]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    # Invalidate the input/output name lists for ERROR_IF tests
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, rfft_attr)
    return results
2521
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test.

    The output tensor comes from ``OutputShaper.addShapeOp``. ``qinfo`` is not
    used here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    compliance_info = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance_info)
2567
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters for generating ``op`` tests.

    Args:
        op: operator description; ``op["rank"]`` is the (min, max) supported
            rank and ``op["types"]`` the supported dtypes (list entries are
            matched on their first element).
        shapeFilter: requested shapes, or a falsy value for none.
        rankFilter: requested ranks, or None for the default ranks.
        dtypeFilter: requested dtypes, or None for all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests; it is called with
            ``check=False`` to obtain its ``param_reqs``.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or any other testType.
    """
    # Default testing ranks are 1-4 inclusive to keep test sizes reasonably small
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    lo_rank, hi_rank = op["rank"]
    if rankFilter is not None:
        # Only keep requested ranks that the operator allows
        ranks = [r for r in rankFilter if lo_rank <= r <= hi_rank]
    elif shapes[0] is None:
        # No ranks or shapes requested: use whichever of the operator's rank
        # range and the default range is smaller
        op_ranks = range(lo_rank, hi_rank + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(lo_rank, hi_rank + 1)

    if dtypeFilter is None:
        dtypes = op["types"]
    else:
        # Operator dtypes that were requested (list entries match on [0])
        dtypes = [
            dt
            for dt in op["types"]
            if dt in dtypeFilter or (isinstance(dt, list) and dt[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}
    if testType != "negative" or validator is None:
        return None

    # Validator requirements override the filters where they are set
    param_reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = param_reqs["rank"]
    dtype_req = param_reqs["dtype"]
    shape_req = param_reqs["shape"]
    return {
        # Without a required shape, reduce the shapes to keep test numbers small
        "shapeFilter": shapes[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2648
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Create the list of tests to generate for a single operator.

    Args:
        opName: Key of the operator in ``self.TOSA_OP_LIST``.
        shapeFilter: List of target shapes passed to
            ``self.create_filter_lists``; a ``None`` entry means no specific
            shape. Defaults to ``[None]``.
        rankFilter: Optional rank filter passed to
            ``self.create_filter_lists``.
        dtypeFilter: Optional dtype filter passed to
            ``self.create_filter_lists``.
        testType: ``"positive"`` for normal tests, or ``"negative"`` to
            create one set of ERROR_IF tests per entry of the op's
            ``error_if_validators``.

    Returns:
        List of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. For positive tests, entries rejected by any of the op's
        ``invalid_test_validators`` are removed.

    Raises:
        Exception: If ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        # Default kept out of the signature to avoid a shared mutable list.
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # Single pass without an ERROR_IF validator
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator; skip it without
            # discarding tests already created for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (str, args) tuples: the
                    # string goes into the test name and args are passed on
                    # to the build function.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2755
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the operator graph for one test and serialize it.

    Args:
        opName: Key of the operator in ``self.TOSA_OP_LIST``.
        testStr: Test name, passed to ``self.createSerializer``.
        dtype_or_dtypeList: A single dtype applied to every operand, or a
            list with one dtype per operand.
        error_name: ERROR_IF error being tested, or ``None``.
        shapeList: List of operand shapes.
        testArgs: Either a dict (new interface; must contain ``"dg_type"``)
            or a list of extra positional build-function arguments (old
            interface).

    Raises:
        Exception: If ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: Re-raised from an old-interface build function call after
            printing the call details to aid debugging.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # CONCAT ops take a variable number of inputs, so their dtype list is
    # sized from shapeList and the operand-count checks are skipped.
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Optional quantization info generator
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        try:
            if error_if_validators is None:
                # Without ERROR_IF validators, qinfo (if any) is passed
                # positionally after the test arguments.
                extra_args = () if qinfo is None else (qinfo,)
                result = build_fcn(self, op, *tens, *testArgs, *extra_args)
            else:
                kwargs = {
                    "validator_fcns": error_if_validators,
                    "error_name": error_name,
                }
                if qinfo is not None:
                    kwargs["qinfo"] = qinfo
                result = build_fcn(self, op, *tens, *testArgs, **kwargs)
        except TypeError:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002881
def createDynamicOpLists(self):
    """Expand the convolution ``*_TEMPLATE`` entries of ``self.TOSA_OP_LIST``.

    For every kernel size, a shallow copy of the matching template is
    registered under a name such as ``conv2d_3x3``, with ``filter`` set to
    that kernel and ``template`` set to False. Every entry still flagged as
    a template is then removed. Once ``conv2d_TEMPLATE`` is gone, calling
    this again does nothing.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already expanded (can occur when the class is initialized more
        # than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def _register(prefix, kernel):
        # Specialise a copy of "<prefix>_TEMPLATE" for one kernel size.
        name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        entry = op_list["{}_TEMPLATE".format(prefix)].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[name] = entry

    # Insertion order (per kernel: conv2d, depthwise, transpose) is kept
    # because it determines the order in which ops are later iterated.
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            _register(prefix, kernel)
    for kernel in kernels_3d:
        _register("conv3d", kernel)

    # Collect the template keys first, since deleting entries while
    # iterating the dictionary is not allowed.
    stale = [name for name, entry in op_list.items() if entry.get("template")]
    for name in stale:
        del op_list[name]
2937
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Look for missing required fields (datastructure linting): every entry of
    ``self.TOSA_OP_LIST`` must have an ``operands`` 2-tuple, a ``build_fcn``
    4-tuple, a ``types`` field and an ``op`` field. A missing ``rank`` is
    set to ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: Naming the first op with a missing or malformed required
            field; the underlying lookup/unpacking error is chained as
            ``__cause__``.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields (unpacked only to check their arity)
        try:
            _, _ = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _, _, _, _ = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _ = entry["types"]
        except KeyError as err:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _ = entry["op"]
        except KeyError as err:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            ) from err

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002979
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op value for the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#         if not specified, initOpListDefaults sets DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TosaTensorGen shape function,
#              TosaTensorValuesGen function, TosaArgGen function)
# 'types': list of datatypes to be tested
# Optional keys used elsewhere in this generator: 'qgen',
# 'invalid_test_validators', 'error_if_validators', 'data_gen',
# 'compliance', and (for templated convolutions) 'filter' and 'template'.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integer types plus the floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range applied by initOpListDefaults to ops without a 'rank' entry
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003031
3032 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003033 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003034 "argmax": {
3035 "op": Op.ARGMAX,
3036 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003037 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003038 "build_fcn": (
3039 build_argmax,
3040 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003041 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003042 TosaArgGen.agAxis,
3043 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003044 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003045 "error_if_validators": (
3046 TosaErrorValidator.evAxisSmallerZero,
3047 TosaErrorValidator.evAxisLargerRank,
3048 TosaErrorValidator.evArgmaxOutputRankMismatch,
3049 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3050 TosaErrorValidator.evWrongRank,
3051 TosaErrorValidator.evWrongInputType,
3052 TosaErrorValidator.evWrongOutputType,
3053 TosaErrorValidator.evWrongInputList,
3054 TosaErrorValidator.evWrongOutputList,
3055 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003056 "data_gen": {
3057 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3058 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 "avg_pool2d": {
3061 "op": Op.AVG_POOL2D,
3062 "operands": (1, 0),
3063 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003064 "build_fcn": (
3065 build_pool2d,
3066 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003067 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003068 TosaArgGen.agPooling,
3069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "qgen": TosaQuantGen.qgUnary,
3071 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003072 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003073 "error_if_validators": (
3074 TosaErrorValidator.evKernelSmallerOne,
3075 TosaErrorValidator.evStrideSmallerOne,
3076 TosaErrorValidator.evPadSmallerZero,
3077 TosaErrorValidator.evWrongRank,
3078 TosaErrorValidator.evWrongInputType,
3079 TosaErrorValidator.evWrongOutputType,
3080 TosaErrorValidator.evWrongInputList,
3081 TosaErrorValidator.evWrongOutputList,
3082 TosaErrorValidator.evInputZeroPointNotZero,
3083 TosaErrorValidator.evOutputZeroPointNotZero,
3084 TosaErrorValidator.evPadLargerEqualKernel,
3085 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003086 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003087 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003088 "data_gen": {
3089 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3090 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003091 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003092 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003093 "conv2d_TEMPLATE": {
3094 "op": Op.CONV2D,
3095 "operands": (1, 2),
3096 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003097 "build_fcn": (
3098 build_conv2d,
3099 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003100 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003101 TosaArgGen.agConv,
3102 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003103 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003104 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003105 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3106 "error_if_validators": (
3107 TosaErrorValidator.evWrongInputType,
3108 TosaErrorValidator.evWrongOutputType,
3109 TosaErrorValidator.evWrongInputList,
3110 TosaErrorValidator.evWrongOutputList,
3111 TosaErrorValidator.evInputZeroPointNotZero,
3112 TosaErrorValidator.evWeightZeroPointNotZero,
3113 TosaErrorValidator.evPadSmallerZero,
3114 TosaErrorValidator.evStrideSmallerOne,
3115 TosaErrorValidator.evDilationSmallerOne,
3116 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003117 TosaErrorValidator.evConvOutputShapeMismatch,
3118 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003119 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003120 "data_gen": {
3121 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3122 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003123 "template": True,
3124 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003125 # Templated operator. Filled in by createDynamicOpLists
3126 "conv3d_TEMPLATE": {
3127 "op": Op.CONV3D,
3128 "operands": (1, 2),
3129 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 "build_fcn": (
3131 build_conv3d,
3132 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003133 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003134 TosaArgGen.agConv,
3135 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003136 "qgen": TosaQuantGen.qgConv,
3137 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003138 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3139 "error_if_validators": (
3140 TosaErrorValidator.evWrongInputType,
3141 TosaErrorValidator.evWrongOutputType,
3142 TosaErrorValidator.evWrongInputList,
3143 TosaErrorValidator.evWrongOutputList,
3144 TosaErrorValidator.evInputZeroPointNotZero,
3145 TosaErrorValidator.evWeightZeroPointNotZero,
3146 TosaErrorValidator.evPadSmallerZero,
3147 TosaErrorValidator.evStrideSmallerOne,
3148 TosaErrorValidator.evDilationSmallerOne,
3149 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003150 TosaErrorValidator.evConvOutputShapeMismatch,
3151 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003152 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003153 "template": True,
3154 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003155 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003156 "depthwise_conv2d_TEMPLATE": {
3157 "op": Op.DEPTHWISE_CONV2D,
3158 "operands": (1, 2),
3159 "filter": [1, 1],
3160 "rank": (4, 4),
3161 "build_fcn": (
3162 build_depthwise_conv2d,
3163 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003164 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003165 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003166 ),
3167 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003168 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003169 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3170 "error_if_validators": (
3171 TosaErrorValidator.evWrongInputType,
3172 TosaErrorValidator.evWrongOutputType,
3173 TosaErrorValidator.evWrongInputList,
3174 TosaErrorValidator.evWrongOutputList,
3175 TosaErrorValidator.evInputZeroPointNotZero,
3176 TosaErrorValidator.evWeightZeroPointNotZero,
3177 TosaErrorValidator.evPadSmallerZero,
3178 TosaErrorValidator.evStrideSmallerOne,
3179 TosaErrorValidator.evDilationSmallerOne,
3180 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003181 TosaErrorValidator.evConvOutputShapeMismatch,
3182 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003183 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003184 "template": True,
3185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 "fully_connected": {
3187 "op": Op.FULLY_CONNECTED,
3188 "operands": (1, 2),
3189 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 "build_fcn": (
3191 build_fully_connected,
3192 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003193 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003194 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003196 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003197 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003198 "error_if_validators": (
3199 TosaErrorValidator.evInputZeroPointNotZero,
3200 TosaErrorValidator.evWeightZeroPointNotZero,
3201 TosaErrorValidator.evWrongRank,
3202 TosaErrorValidator.evWrongInputType,
3203 TosaErrorValidator.evWrongOutputType,
3204 TosaErrorValidator.evWrongInputList,
3205 TosaErrorValidator.evWrongOutputList,
3206 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003207 "data_gen": {
3208 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "matmul": {
3212 "op": Op.MATMUL,
3213 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003214 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 "build_fcn": (
3216 build_matmul,
3217 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003218 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003219 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003221 "qgen": TosaQuantGen.qgMatmul,
3222 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003223 "error_if_validators": (
3224 TosaErrorValidator.evInputZeroPointNotZero,
3225 TosaErrorValidator.evWrongRank,
3226 TosaErrorValidator.evWrongInputType,
3227 TosaErrorValidator.evWrongOutputType,
3228 TosaErrorValidator.evWrongInputList,
3229 TosaErrorValidator.evWrongOutputList,
3230 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003231 "data_gen": {
3232 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 "max_pool2d": {
3236 "op": Op.MAX_POOL2D,
3237 "operands": (1, 0),
3238 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003239 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003240 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003242 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003243 TosaArgGen.agPooling,
3244 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003246 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003247 "error_if_validators": (
3248 TosaErrorValidator.evKernelSmallerOne,
3249 TosaErrorValidator.evStrideSmallerOne,
3250 TosaErrorValidator.evPadSmallerZero,
3251 TosaErrorValidator.evWrongRank,
3252 TosaErrorValidator.evWrongInputType,
3253 TosaErrorValidator.evWrongOutputType,
3254 TosaErrorValidator.evWrongInputList,
3255 TosaErrorValidator.evWrongOutputList,
3256 TosaErrorValidator.evPadLargerEqualKernel,
3257 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003258 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003259 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003260 "data_gen": {
3261 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003264 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003265 "transpose_conv2d_TEMPLATE": {
3266 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003267 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003268 "rank": (4, 4),
3269 "build_fcn": (
3270 build_transpose_conv2d,
3271 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003273 TosaArgGen.agTransposeConv2D,
3274 ),
3275 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003276 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003277 "invalid_test_validators": (
3278 TosaInvalidValidator.ivHeightWidthInvalid,
3279 TosaInvalidValidator.ivNonPositiveOutputShape,
3280 ),
3281 "error_if_validators": (
3282 TosaErrorValidator.evWrongInputType,
3283 TosaErrorValidator.evWrongOutputType,
3284 TosaErrorValidator.evWrongInputList,
3285 TosaErrorValidator.evWrongOutputList,
3286 TosaErrorValidator.evInputZeroPointNotZero,
3287 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003288 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003289 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003290 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003291 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003292 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003293 "template": True,
3294 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003295 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003296 "clamp": {
3297 "op": Op.CLAMP,
3298 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003299 "build_fcn": (
3300 build_clamp,
3301 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003302 TosaTensorValuesGen.tvgLazyGenDefault,
3303 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003304 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003305 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003306 "error_if_validators": (
3307 TosaErrorValidator.evMaxSmallerMin,
3308 TosaErrorValidator.evWrongInputType,
3309 TosaErrorValidator.evWrongOutputType,
3310 TosaErrorValidator.evWrongInputList,
3311 TosaErrorValidator.evWrongOutputList,
3312 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003313 "data_gen": {
3314 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3315 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003316 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003317 "sigmoid": {
3318 "op": Op.SIGMOID,
3319 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003321 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003322 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003323 TosaTensorValuesGen.tvgLazyGenDefault,
3324 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003325 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003326 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003327 "error_if_validators": (
3328 TosaErrorValidator.evWrongInputType,
3329 TosaErrorValidator.evWrongOutputType,
3330 TosaErrorValidator.evWrongInputList,
3331 TosaErrorValidator.evWrongOutputList,
3332 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003333 "data_gen": {
3334 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3335 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003336 },
3337 "tanh": {
3338 "op": Op.TANH,
3339 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003341 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003343 TosaTensorValuesGen.tvgLazyGenDefault,
3344 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003345 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003346 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 "error_if_validators": (
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003353 "data_gen": {
3354 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3355 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003356 "compliance": {
3357 "abs_error_lower_bound": 0.5,
3358 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003359 },
Won Jeon78155c62023-06-10 00:20:04 +00003360 "erf": {
3361 "op": Op.ERF,
3362 "operands": (1, 0),
3363 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003364 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003365 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003366 TosaTensorValuesGen.tvgLazyGenDefault,
3367 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003368 ),
3369 "types": TYPE_FP,
3370 "error_if_validators": (
3371 TosaErrorValidator.evWrongInputType,
3372 TosaErrorValidator.evWrongOutputType,
3373 TosaErrorValidator.evWrongInputList,
3374 TosaErrorValidator.evWrongOutputList,
3375 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003376 "data_gen": {
3377 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3378 },
3379 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 # Elementwise Binary Operators
3382 "add": {
3383 "op": Op.ADD,
3384 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003385 "build_fcn": (
3386 build_binary_broadcast,
3387 TosaTensorGen.tgBroadcastFuzz,
3388 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003389 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003390 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003391 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003392 "error_if_validators": (
3393 TosaErrorValidator.evRankMismatch,
3394 TosaErrorValidator.evWrongInputType,
3395 TosaErrorValidator.evWrongOutputType,
3396 TosaErrorValidator.evWrongInputList,
3397 TosaErrorValidator.evWrongOutputList,
3398 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003399 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003400 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003401 "data_gen": {
3402 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3403 },
3404 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "arithmetic_right_shift": {
3407 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3408 "operands": (2, 0),
3409 "build_fcn": (
3410 build_arithmetic_right_shift,
3411 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 TosaArgGen.agArithmeticRightShift,
3414 ),
3415 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 "error_if_validators": (
3417 TosaErrorValidator.evRankMismatch,
3418 TosaErrorValidator.evWrongInputType,
3419 TosaErrorValidator.evWrongOutputType,
3420 TosaErrorValidator.evWrongInputList,
3421 TosaErrorValidator.evWrongOutputList,
3422 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003423 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 "bitwise_and": {
3427 "op": Op.BITWISE_AND,
3428 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 "build_fcn": (
3430 build_binary_broadcast,
3431 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003432 TosaTensorValuesGen.tvgLazyGenDefault,
3433 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 "error_if_validators": (
3437 TosaErrorValidator.evRankMismatch,
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evWrongOutputType,
3440 TosaErrorValidator.evWrongInputList,
3441 TosaErrorValidator.evWrongOutputList,
3442 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003443 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003446 "bitwise_or": {
3447 "op": Op.BITWISE_OR,
3448 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 "build_fcn": (
3450 build_binary_broadcast,
3451 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003452 TosaTensorValuesGen.tvgLazyGenDefault,
3453 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003454 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003456 "error_if_validators": (
3457 TosaErrorValidator.evRankMismatch,
3458 TosaErrorValidator.evWrongInputType,
3459 TosaErrorValidator.evWrongOutputType,
3460 TosaErrorValidator.evWrongInputList,
3461 TosaErrorValidator.evWrongOutputList,
3462 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003463 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003466 "bitwise_xor": {
3467 "op": Op.BITWISE_XOR,
3468 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003469 "build_fcn": (
3470 build_binary_broadcast,
3471 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003472 TosaTensorValuesGen.tvgLazyGenDefault,
3473 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003474 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003476 "error_if_validators": (
3477 TosaErrorValidator.evRankMismatch,
3478 TosaErrorValidator.evWrongInputType,
3479 TosaErrorValidator.evWrongOutputType,
3480 TosaErrorValidator.evWrongInputList,
3481 TosaErrorValidator.evWrongOutputList,
3482 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003483 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003484 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003486 "intdiv": {
3487 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003488 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003489 "build_fcn": (
3490 build_binary_broadcast,
3491 TosaTensorGen.tgBroadcastFuzz,
3492 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003493 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003495 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003496 "error_if_validators": (
3497 TosaErrorValidator.evRankMismatch,
3498 TosaErrorValidator.evWrongInputType,
3499 TosaErrorValidator.evWrongOutputType,
3500 TosaErrorValidator.evWrongInputList,
3501 TosaErrorValidator.evWrongOutputList,
3502 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003503 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003504 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003505 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 "logical_and": {
3507 "op": Op.LOGICAL_AND,
3508 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 "build_fcn": (
3510 build_binary_broadcast,
3511 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003512 TosaTensorValuesGen.tvgLazyGenDefault,
3513 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 "error_if_validators": (
3517 TosaErrorValidator.evRankMismatch,
3518 TosaErrorValidator.evWrongInputType,
3519 TosaErrorValidator.evWrongOutputType,
3520 TosaErrorValidator.evWrongInputList,
3521 TosaErrorValidator.evWrongOutputList,
3522 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003523 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "logical_left_shift": {
3527 "op": Op.LOGICAL_LEFT_SHIFT,
3528 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_binary_broadcast,
3531 TosaTensorGen.tgBroadcastFuzz,
3532 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003533 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 "error_if_validators": (
3537 TosaErrorValidator.evRankMismatch,
3538 TosaErrorValidator.evWrongInputType,
3539 TosaErrorValidator.evWrongOutputType,
3540 TosaErrorValidator.evWrongInputList,
3541 TosaErrorValidator.evWrongOutputList,
3542 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003543 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003546 "logical_right_shift": {
3547 "op": Op.LOGICAL_RIGHT_SHIFT,
3548 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003549 "build_fcn": (
3550 build_binary_broadcast,
3551 TosaTensorGen.tgBroadcastFuzz,
3552 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003553 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003556 "error_if_validators": (
3557 TosaErrorValidator.evRankMismatch,
3558 TosaErrorValidator.evWrongInputType,
3559 TosaErrorValidator.evWrongOutputType,
3560 TosaErrorValidator.evWrongInputList,
3561 TosaErrorValidator.evWrongOutputList,
3562 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003563 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003566 "logical_or": {
3567 "op": Op.LOGICAL_OR,
3568 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 "build_fcn": (
3570 build_binary_broadcast,
3571 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003572 TosaTensorValuesGen.tvgLazyGenDefault,
3573 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 "error_if_validators": (
3577 TosaErrorValidator.evRankMismatch,
3578 TosaErrorValidator.evWrongInputType,
3579 TosaErrorValidator.evWrongOutputType,
3580 TosaErrorValidator.evWrongInputList,
3581 TosaErrorValidator.evWrongOutputList,
3582 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003583 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003586 "logical_xor": {
3587 "op": Op.LOGICAL_XOR,
3588 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003589 "build_fcn": (
3590 build_binary_broadcast,
3591 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003592 TosaTensorValuesGen.tvgLazyGenDefault,
3593 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003594 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003596 "error_if_validators": (
3597 TosaErrorValidator.evRankMismatch,
3598 TosaErrorValidator.evWrongInputType,
3599 TosaErrorValidator.evWrongOutputType,
3600 TosaErrorValidator.evWrongInputList,
3601 TosaErrorValidator.evWrongOutputList,
3602 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003603 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003604 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 "maximum": {
3607 "op": Op.MAXIMUM,
3608 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003609 "build_fcn": (
3610 build_binary_broadcast,
3611 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003612 TosaTensorValuesGen.tvgLazyGenDefault,
3613 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 "error_if_validators": (
3617 TosaErrorValidator.evRankMismatch,
3618 TosaErrorValidator.evWrongInputType,
3619 TosaErrorValidator.evWrongOutputType,
3620 TosaErrorValidator.evWrongInputList,
3621 TosaErrorValidator.evWrongOutputList,
3622 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003623 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003625 "data_gen": {
3626 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3627 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "minimum": {
3630 "op": Op.MINIMUM,
3631 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
3633 build_binary_broadcast,
3634 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003635 TosaTensorValuesGen.tvgLazyGenDefault,
3636 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evRankMismatch,
3641 TosaErrorValidator.evWrongInputType,
3642 TosaErrorValidator.evWrongOutputType,
3643 TosaErrorValidator.evWrongInputList,
3644 TosaErrorValidator.evWrongOutputList,
3645 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003646 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003648 "data_gen": {
3649 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3650 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 "mul": {
3653 "op": Op.MUL,
3654 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 "build_fcn": (
3656 build_mul,
3657 TosaTensorGen.tgBroadcastFuzz,
3658 TosaTensorValuesGen.tvgMul,
3659 TosaArgGen.agMul,
3660 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 "error_if_validators": (
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 TosaErrorValidator.evRankMismatch,
3668 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003669 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003670 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003671 "data_gen": {
3672 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3673 },
3674 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "pow": {
3677 "op": Op.POW,
3678 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_binary_broadcast,
3681 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003682 TosaTensorValuesGen.tvgPow,
3683 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003686 "error_if_validators": (
3687 TosaErrorValidator.evRankMismatch,
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003693 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003695 "data_gen": {
3696 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003699 "sub": {
3700 "op": Op.SUB,
3701 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 "build_fcn": (
3703 build_binary_broadcast,
3704 TosaTensorGen.tgBroadcastFuzz,
3705 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003706 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 "error_if_validators": (
3710 TosaErrorValidator.evRankMismatch,
3711 TosaErrorValidator.evWrongInputType,
3712 TosaErrorValidator.evWrongOutputType,
3713 TosaErrorValidator.evWrongInputList,
3714 TosaErrorValidator.evWrongOutputList,
3715 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003716 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003718 "data_gen": {
3719 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3720 },
3721 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 "table": {
3724 "op": Op.TABLE,
3725 # Use the automatic generation functions to create the input array
3726 # but create the table tensor in the build function, as it may be
3727 # a different type from the input
3728 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 "build_fcn": (
3730 build_table,
3731 TosaTensorGen.tgBasic,
3732 TosaTensorValuesGen.tvgDefault,
3733 TosaArgGen.agTable,
3734 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003735 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evWrongInputType,
3738 TosaErrorValidator.evWrongOutputType,
3739 TosaErrorValidator.evWrongInputList,
3740 TosaErrorValidator.evWrongOutputList,
3741 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003742 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 # Elementwise Unary operators
3744 "abs": {
3745 "op": Op.ABS,
3746 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 "build_fcn": (
3748 build_unary,
3749 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003750 TosaTensorValuesGen.tvgLazyGenDefault,
3751 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 "error_if_validators": (
3755 TosaErrorValidator.evWrongInputType,
3756 TosaErrorValidator.evWrongOutputType,
3757 TosaErrorValidator.evWrongInputList,
3758 TosaErrorValidator.evWrongOutputList,
3759 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003760 "data_gen": {
3761 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3762 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003763 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 "bitwise_not": {
3765 "op": Op.BITWISE_NOT,
3766 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 "build_fcn": (
3768 build_unary,
3769 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003770 TosaTensorValuesGen.tvgLazyGenDefault,
3771 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003774 "error_if_validators": (
3775 TosaErrorValidator.evWrongInputType,
3776 TosaErrorValidator.evWrongOutputType,
3777 TosaErrorValidator.evWrongInputList,
3778 TosaErrorValidator.evWrongOutputList,
3779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "ceil": {
3782 "op": Op.CEIL,
3783 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_unary,
3786 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003787 TosaTensorValuesGen.tvgLazyGenDefault,
3788 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "error_if_validators": (
3792 TosaErrorValidator.evWrongInputType,
3793 TosaErrorValidator.evWrongOutputType,
3794 TosaErrorValidator.evWrongInputList,
3795 TosaErrorValidator.evWrongOutputList,
3796 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003797 "data_gen": {
3798 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3799 },
3800 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 "clz": {
3803 "op": Op.CLZ,
3804 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 "build_fcn": (
3806 build_unary,
3807 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003808 TosaTensorValuesGen.tvgLazyGenDefault,
3809 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003810 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003811 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003812 "error_if_validators": (
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "exp": {
3820 "op": Op.EXP,
3821 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 "build_fcn": (
3823 build_unary,
3824 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003825 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003826 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003829 "error_if_validators": (
3830 TosaErrorValidator.evWrongInputType,
3831 TosaErrorValidator.evWrongOutputType,
3832 TosaErrorValidator.evWrongInputList,
3833 TosaErrorValidator.evWrongOutputList,
3834 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003835 "data_gen": {
3836 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3837 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003838 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 "floor": {
3840 "op": Op.FLOOR,
3841 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003842 "build_fcn": (
3843 build_unary,
3844 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003845 TosaTensorValuesGen.tvgLazyGenDefault,
3846 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003848 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003849 "error_if_validators": (
3850 TosaErrorValidator.evWrongInputType,
3851 TosaErrorValidator.evWrongOutputType,
3852 TosaErrorValidator.evWrongInputList,
3853 TosaErrorValidator.evWrongOutputList,
3854 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003855 "data_gen": {
3856 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3857 },
3858 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003859 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 "log": {
3861 "op": Op.LOG,
3862 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003863 "build_fcn": (
3864 build_unary,
3865 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003866 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003867 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003868 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003869 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003870 "error_if_validators": (
3871 TosaErrorValidator.evWrongInputType,
3872 TosaErrorValidator.evWrongOutputType,
3873 TosaErrorValidator.evWrongInputList,
3874 TosaErrorValidator.evWrongOutputList,
3875 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003876 "data_gen": {
3877 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3878 },
3879 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "logical_not": {
3882 "op": Op.LOGICAL_NOT,
3883 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 "build_fcn": (
3885 build_unary,
3886 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003887 TosaTensorValuesGen.tvgLazyGenDefault,
3888 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 "error_if_validators": (
3892 TosaErrorValidator.evWrongInputType,
3893 TosaErrorValidator.evWrongOutputType,
3894 TosaErrorValidator.evWrongInputList,
3895 TosaErrorValidator.evWrongOutputList,
3896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "negate": {
3899 "op": Op.NEGATE,
3900 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003901 "build_fcn": (
3902 build_unary,
3903 TosaTensorGen.tgBasic,
3904 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003905 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "qgen": TosaQuantGen.qgUnary,
3908 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 "error_if_validators": (
3910 TosaErrorValidator.evInputZeroPointNotZero,
3911 TosaErrorValidator.evOutputZeroPointNotZero,
3912 TosaErrorValidator.evWrongInputType,
3913 TosaErrorValidator.evWrongOutputType,
3914 TosaErrorValidator.evWrongInputList,
3915 TosaErrorValidator.evWrongOutputList,
3916 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003917 "data_gen": {
3918 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3919 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "reciprocal": {
3922 "op": Op.RECIPROCAL,
3923 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_unary,
3926 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
3928 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evWrongInputType,
3933 TosaErrorValidator.evWrongOutputType,
3934 TosaErrorValidator.evWrongInputList,
3935 TosaErrorValidator.evWrongOutputList,
3936 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003937 "data_gen": {
3938 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3939 },
3940 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003942 "rsqrt": {
3943 "op": Op.RSQRT,
3944 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003945 "build_fcn": (
3946 build_unary,
3947 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003948 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003949 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003951 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003952 "error_if_validators": (
3953 TosaErrorValidator.evWrongInputType,
3954 TosaErrorValidator.evWrongOutputType,
3955 TosaErrorValidator.evWrongInputList,
3956 TosaErrorValidator.evWrongOutputList,
3957 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003958 "data_gen": {
3959 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3960 },
3961 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 # Elementwise Ternary operators
3964 "select": {
3965 "op": Op.SELECT,
3966 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 "build_fcn": (
3968 build_select,
3969 TosaTensorGen.tgBroadcastFuzz,
3970 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003971 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003973 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003974 "error_if_validators": (
3975 TosaErrorValidator.evRankMismatch,
3976 TosaErrorValidator.evWrongInputType,
3977 TosaErrorValidator.evWrongOutputType,
3978 TosaErrorValidator.evWrongInputList,
3979 TosaErrorValidator.evWrongOutputList,
3980 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003981 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003982 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003983 "data_gen": {
3984 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003987 # Comparison operators
3988 "equal": {
3989 "op": Op.EQUAL,
3990 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003991 "build_fcn": (
3992 build_comparison,
3993 TosaTensorGen.tgBroadcastFuzz,
3994 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003995 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003996 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003997 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003998 "error_if_validators": (
3999 TosaErrorValidator.evRankMismatch,
4000 TosaErrorValidator.evWrongInputType,
4001 TosaErrorValidator.evWrongOutputType,
4002 TosaErrorValidator.evWrongInputList,
4003 TosaErrorValidator.evWrongOutputList,
4004 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004005 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004006 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004007 "data_gen": {
4008 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4009 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 "greater_equal": {
4012 "op": Op.GREATER_EQUAL,
4013 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 "build_fcn": (
4015 build_comparison,
4016 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004017 TosaTensorValuesGen.tvgLazyGenDefault,
4018 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 "error_if_validators": (
4022 TosaErrorValidator.evRankMismatch,
4023 TosaErrorValidator.evWrongInputType,
4024 TosaErrorValidator.evWrongOutputType,
4025 TosaErrorValidator.evWrongInputList,
4026 TosaErrorValidator.evWrongOutputList,
4027 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004028 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004029 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004030 "data_gen": {
4031 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4032 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004034 "greater": {
4035 "op": Op.GREATER,
4036 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004037 "build_fcn": (
4038 build_comparison,
4039 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004040 TosaTensorValuesGen.tvgLazyGenDefault,
4041 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004042 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004043 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004044 "error_if_validators": (
4045 TosaErrorValidator.evRankMismatch,
4046 TosaErrorValidator.evWrongInputType,
4047 TosaErrorValidator.evWrongOutputType,
4048 TosaErrorValidator.evWrongInputList,
4049 TosaErrorValidator.evWrongOutputList,
4050 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004051 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004052 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004053 "data_gen": {
4054 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004057 # Reduction operators
4058 "reduce_all": {
4059 "op": Op.REDUCE_ALL,
4060 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004061 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_reduce,
4064 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004065 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004066 TosaArgGen.agAxis,
4067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evAxisLargerRank,
4071 TosaErrorValidator.evAxisSmallerZero,
4072 TosaErrorValidator.evShapeOfAxisNotOne,
4073 TosaErrorValidator.evWrongInputType,
4074 TosaErrorValidator.evWrongOutputType,
4075 TosaErrorValidator.evWrongRank,
4076 TosaErrorValidator.evWrongInputList,
4077 TosaErrorValidator.evWrongOutputList,
4078 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004080 "reduce_any": {
4081 "op": Op.REDUCE_ANY,
4082 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004083 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004084 "build_fcn": (
4085 build_reduce,
4086 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004087 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004088 TosaArgGen.agAxis,
4089 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004090 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004091 "error_if_validators": (
4092 TosaErrorValidator.evAxisLargerRank,
4093 TosaErrorValidator.evAxisSmallerZero,
4094 TosaErrorValidator.evShapeOfAxisNotOne,
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongRank,
4098 TosaErrorValidator.evWrongInputList,
4099 TosaErrorValidator.evWrongOutputList,
4100 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004101 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 "reduce_max": {
4103 "op": Op.REDUCE_MAX,
4104 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004105 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004106 "build_fcn": (
4107 build_reduce,
4108 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004109 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004110 TosaArgGen.agAxis,
4111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004112 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004113 "error_if_validators": (
4114 TosaErrorValidator.evAxisLargerRank,
4115 TosaErrorValidator.evAxisSmallerZero,
4116 TosaErrorValidator.evShapeOfAxisNotOne,
4117 TosaErrorValidator.evWrongInputType,
4118 TosaErrorValidator.evWrongOutputType,
4119 TosaErrorValidator.evWrongRank,
4120 TosaErrorValidator.evWrongInputList,
4121 TosaErrorValidator.evWrongOutputList,
4122 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004123 "data_gen": {
4124 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004128 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004129 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004130 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 "build_fcn": (
4132 build_reduce,
4133 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004134 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004135 TosaArgGen.agAxis,
4136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 "error_if_validators": (
4139 TosaErrorValidator.evAxisLargerRank,
4140 TosaErrorValidator.evAxisSmallerZero,
4141 TosaErrorValidator.evShapeOfAxisNotOne,
4142 TosaErrorValidator.evWrongInputType,
4143 TosaErrorValidator.evWrongOutputType,
4144 TosaErrorValidator.evWrongRank,
4145 TosaErrorValidator.evWrongInputList,
4146 TosaErrorValidator.evWrongOutputList,
4147 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004148 "data_gen": {
4149 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004152 "reduce_product": {
4153 "op": Op.REDUCE_PRODUCT,
4154 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004155 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 "build_fcn": (
4157 build_reduce,
4158 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004159 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 TosaArgGen.agAxis,
4161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004162 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004163 "error_if_validators": (
4164 TosaErrorValidator.evAxisLargerRank,
4165 TosaErrorValidator.evAxisSmallerZero,
4166 TosaErrorValidator.evShapeOfAxisNotOne,
4167 TosaErrorValidator.evWrongInputType,
4168 TosaErrorValidator.evWrongOutputType,
4169 TosaErrorValidator.evWrongRank,
4170 TosaErrorValidator.evWrongInputList,
4171 TosaErrorValidator.evWrongOutputList,
4172 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004173 "data_gen": {
4174 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004177 "reduce_sum": {
4178 "op": Op.REDUCE_SUM,
4179 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004180 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 "build_fcn": (
4182 build_reduce,
4183 TosaTensorGen.tgBasic,
4184 TosaTensorValuesGen.tvgReduceSum,
4185 TosaArgGen.agAxis,
4186 ),
James Ward24dbc422022-10-19 12:20:31 +01004187 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 "error_if_validators": (
4189 TosaErrorValidator.evAxisLargerRank,
4190 TosaErrorValidator.evAxisSmallerZero,
4191 TosaErrorValidator.evShapeOfAxisNotOne,
4192 TosaErrorValidator.evWrongInputType,
4193 TosaErrorValidator.evWrongOutputType,
4194 TosaErrorValidator.evWrongRank,
4195 TosaErrorValidator.evWrongInputList,
4196 TosaErrorValidator.evWrongOutputList,
4197 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004198 "data_gen": {
4199 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004202 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004203 "concat": {
4204 "op": Op.CONCAT,
4205 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004206 "build_fcn": (
4207 build_concat,
4208 TosaTensorGen.tgConcat,
4209 TosaTensorValuesGen.tvgConcat,
4210 TosaArgGen.agAxis,
4211 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004212 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004213 "error_if_validators": (
4214 TosaErrorValidator.evAxisLargerRank,
4215 TosaErrorValidator.evAxisSmallerZero,
4216 TosaErrorValidator.evConcatInputRankMismatch,
4217 TosaErrorValidator.evConcatShapeSumMismatch,
4218 TosaErrorValidator.evConcatInputDimMismatch,
4219 TosaErrorValidator.evWrongInputType,
4220 TosaErrorValidator.evWrongOutputType,
4221 TosaErrorValidator.evWrongOutputList,
4222 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004223 "data_gen": {
4224 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4225 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004226 },
4227 "pad": {
4228 "op": Op.PAD,
4229 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 "build_fcn": (
4231 build_pad,
4232 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004233 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004234 TosaArgGen.agPad,
4235 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004236 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004237 "error_if_validators": (
4238 TosaErrorValidator.evWrongInputType,
4239 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004240 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004241 TosaErrorValidator.evWrongOutputType,
4242 TosaErrorValidator.evWrongInputList,
4243 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004244 TosaErrorValidator.evRankMismatch,
4245 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004246 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004247 "data_gen": {
4248 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4249 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004250 },
Won Jeona21b2e82023-08-10 10:33:01 +00004251 "dim": {
4252 "op": Op.DIM,
4253 "operands": (1, 0),
4254 "build_fcn": (
4255 build_dim,
4256 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004257 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004258 TosaArgGen.agAxis,
4259 ),
4260 "types": TYPE_FIB,
4261 "error_if_validators": (
4262 TosaErrorValidator.evAxisLargerRank,
4263 TosaErrorValidator.evAxisSmallerZero,
4264 TosaErrorValidator.evWrongInputType,
4265 TosaErrorValidator.evWrongInputList,
4266 TosaErrorValidator.evWrongOutputList,
4267 TosaErrorValidator.evWrongRank,
4268 ),
4269 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004270 "reshape": {
4271 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004272 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004273 "build_fcn": (
4274 build_reshape,
4275 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004276 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004277 TosaArgGen.agReshape,
4278 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004279 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004280 "error_if_validators": (
4281 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4282 TosaErrorValidator.evWrongInputType,
4283 TosaErrorValidator.evWrongOutputType,
4284 TosaErrorValidator.evWrongInputList,
4285 TosaErrorValidator.evWrongOutputList,
4286 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004287 "data_gen": {
4288 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4289 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004290 },
4291 "reverse": {
4292 "op": Op.REVERSE,
4293 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004294 "build_fcn": (
4295 build_reverse,
4296 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004297 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004298 TosaArgGen.agAxis,
4299 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004300 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004301 "error_if_validators": (
4302 TosaErrorValidator.evAxisSmallerZero,
4303 TosaErrorValidator.evAxisLargerRank,
4304 TosaErrorValidator.evWrongInputType,
4305 TosaErrorValidator.evWrongOutputType,
4306 TosaErrorValidator.evWrongInputList,
4307 TosaErrorValidator.evWrongOutputList,
4308 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 },
4310 "slice": {
4311 "op": Op.SLICE,
4312 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004313 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004314 "build_fcn": (
4315 build_slice,
4316 TosaTensorGen.tgBasic,
4317 TosaTensorValuesGen.tvgDefault,
4318 TosaArgGen.agSlice,
4319 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004320 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004321 "error_if_validators": (
4322 TosaErrorValidator.evStartSmallerZero,
4323 TosaErrorValidator.evSizeSmallerEqualZero,
4324 TosaErrorValidator.evStartSizeOutsideBounds,
4325 TosaErrorValidator.evSizeOutputShapeMismatch,
4326 TosaErrorValidator.evInputSizeStartLengthMismatch,
4327 TosaErrorValidator.evWrongRank,
4328 TosaErrorValidator.evWrongInputType,
4329 TosaErrorValidator.evWrongOutputType,
4330 TosaErrorValidator.evWrongInputList,
4331 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004332 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004333 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004334 },
4335 "tile": {
4336 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004337 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004338 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 "build_fcn": (
4340 build_tile,
4341 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004342 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004343 TosaArgGen.agTile,
4344 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004345 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004346 "error_if_validators": (
4347 TosaErrorValidator.evWrongInputType,
4348 TosaErrorValidator.evWrongOutputType,
4349 TosaErrorValidator.evWrongInputList,
4350 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004351 TosaErrorValidator.evRankMismatch,
4352 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004353 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004354 "data_gen": {
4355 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4356 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004357 },
4358 "transpose": {
4359 "op": Op.TRANSPOSE,
4360 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004361 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004362 "build_fcn": (
4363 build_transpose,
4364 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004365 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004366 TosaArgGen.agTranspose,
4367 ),
4368 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004369 "error_if_validators": (
4370 TosaErrorValidator.evIndexOutsideBounds,
4371 TosaErrorValidator.evIndexUsedTwice,
4372 TosaErrorValidator.evWrongInputType,
4373 TosaErrorValidator.evWrongOutputType,
4374 TosaErrorValidator.evWrongInputList,
4375 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004376 TosaErrorValidator.evWrongRank,
4377 TosaErrorValidator.evRankMismatch,
4378 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004379 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004381 # Data nodes
4382 "const": {
4383 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004384 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004385 "build_fcn": (
4386 build_const,
4387 TosaTensorGen.tgBasic,
4388 TosaTensorValuesGen.tvgDefault,
4389 None,
4390 ),
Luke Hutton65872422023-02-20 10:33:04 +00004391 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004392 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004393 "identity": {
4394 "op": Op.IDENTITY,
4395 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004396 "build_fcn": (
4397 build_unary,
4398 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004399 TosaTensorValuesGen.tvgLazyGenDefault,
4400 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004401 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004402 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004403 "data_gen": {
4404 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004406 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004407 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004408 "gather": {
4409 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004410 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004411 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004412 "build_fcn": (
4413 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004414 TosaTensorGen.tgGather,
4415 TosaTensorValuesGen.tvgGather,
4416 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004417 ),
James Ward24dbc422022-10-19 12:20:31 +01004418 "types": (
4419 DType.INT8,
4420 DType.INT16,
4421 DType.INT32,
4422 DType.FP16,
4423 DType.BF16,
4424 DType.FP32,
4425 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004426 "error_if_validators": (
4427 TosaErrorValidator.evWrongInputType,
4428 TosaErrorValidator.evWrongOutputType,
4429 TosaErrorValidator.evWrongInputList,
4430 TosaErrorValidator.evWrongOutputList,
4431 TosaErrorValidator.evWrongRank,
4432 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004433 "data_gen": {
4434 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4435 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004436 },
4437 "scatter": {
4438 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004439 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004441 "build_fcn": (
4442 build_scatter,
4443 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004444 TosaTensorValuesGen.tvgScatter,
4445 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004446 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004448 "error_if_validators": (
4449 TosaErrorValidator.evWrongInputType,
4450 TosaErrorValidator.evWrongOutputType,
4451 TosaErrorValidator.evWrongInputList,
4452 TosaErrorValidator.evWrongOutputList,
4453 TosaErrorValidator.evWrongRank,
4454 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004455 "data_gen": {
4456 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4457 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004458 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004459 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 "resize": {
4461 "op": Op.RESIZE,
4462 "operands": (1, 0),
4463 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004464 "build_fcn": (
4465 build_resize,
4466 TosaTensorGen.tgNHWC,
4467 TosaTensorValuesGen.tvgDefault,
4468 TosaArgGen.agResize,
4469 ),
James Ward24dbc422022-10-19 12:20:31 +01004470 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004471 "invalid_test_validators": (
4472 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004473 ),
4474 "error_if_validators": (
4475 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004476 TosaErrorValidator.evScaleSmallerEqualZero,
4477 TosaErrorValidator.evScaleNLargerMax,
4478 TosaErrorValidator.evScaleDLargerMax,
4479 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004480 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004481 TosaErrorValidator.evBorderSmallerMin,
4482 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004483 TosaErrorValidator.evWrongInputType,
4484 TosaErrorValidator.evWrongOutputType,
4485 TosaErrorValidator.evWrongRank,
4486 TosaErrorValidator.evWrongInputList,
4487 TosaErrorValidator.evWrongOutputList,
4488 TosaErrorValidator.evBatchMismatch,
4489 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004490 TosaErrorValidator.evResizeOutputShapeMismatch,
4491 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004492 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004493 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004494 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004495 "cast": {
4496 "op": Op.CAST,
4497 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004498 "build_fcn": (
4499 build_cast,
4500 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004501 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004502 TosaArgGen.agCast,
4503 ),
James Ward8b390432022-08-12 20:48:56 +01004504 "types": (
4505 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004506 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004507 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004508 DType.INT8,
4509 DType.INT16,
4510 DType.INT32,
4511 DType.BOOL,
4512 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004513 "error_if_validators": (
4514 TosaErrorValidator.evWrongInputType,
4515 TosaErrorValidator.evWrongOutputType,
4516 TosaErrorValidator.evWrongInputList,
4517 TosaErrorValidator.evWrongOutputList,
4518 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004519 "data_gen": {
4520 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4521 },
4522 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004523 },
4524 "rescale": {
4525 "op": Op.RESCALE,
4526 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004527 "build_fcn": (
4528 build_rescale,
4529 TosaTensorGen.tgBasic,
4530 TosaTensorValuesGen.tvgDefault,
4531 TosaArgGen.agRescale,
4532 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004533 "types": [
4534 DType.UINT8,
4535 DType.INT8,
4536 DType.INT16,
4537 DType.INT32,
4538 DType.INT48,
4539 DType.UINT16,
4540 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 "error_if_validators": (
4542 TosaErrorValidator.evInputZeroPointNotZero,
4543 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004544 TosaErrorValidator.evU16InputZeroPointNotValid,
4545 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004546 TosaErrorValidator.evScaleTrue,
4547 TosaErrorValidator.evScaleNotTrue,
4548 TosaErrorValidator.evWrongInputType,
4549 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004550 TosaErrorValidator.evWrongInputList,
4551 TosaErrorValidator.evWrongOutputList,
4552 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004553 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004554 # Custom
4555 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004556 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004557 # Two variants of cond_if, one that generates one of two constant tensors (no
4558 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4559 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004560 "cond_if_const": {
4561 "op": Op.COND_IF,
4562 "operands": (0, 2),
4563 "build_fcn": (
4564 build_cond_if_const,
4565 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004566 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004567 TosaArgGen.agCondIf,
4568 ),
4569 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004570 "error_if_validators": (
4571 TosaErrorValidator.evOutputListThenGraphMismatch,
4572 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004573 TosaErrorValidator.evCondIfCondNotMatchingBool,
4574 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004575 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004576 },
4577 "cond_if_binary": {
4578 "op": Op.COND_IF,
4579 "operands": (2, 0),
4580 "build_fcn": (
4581 build_cond_if_binary,
4582 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004583 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004584 TosaArgGen.agCondIf,
4585 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004586 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004587 "error_if_validators": (
4588 TosaErrorValidator.evInputListThenGraphMismatch,
4589 TosaErrorValidator.evInputListElseGraphMismatch,
4590 TosaErrorValidator.evOutputListThenGraphMismatch,
4591 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004592 TosaErrorValidator.evCondIfCondNotMatchingBool,
4593 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004594 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004595 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004596 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004597 "while_loop": {
4598 "op": Op.WHILE_LOOP,
4599 "operands": (0, 1),
4600 "build_fcn": (
4601 build_while_loop,
4602 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004603 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 TosaArgGen.agWhileLoop,
4605 ),
4606 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004607 "error_if_validators": (
4608 TosaErrorValidator.evInputListOutputListMismatch,
4609 TosaErrorValidator.evInputListCondGraphMismatch,
4610 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4611 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4612 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004613 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004614 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004615 },
Luke Hutton57287132023-02-06 14:54:18 +00004616 "fft2d": {
4617 "op": Op.FFT2D,
4618 "operands": (2, 0),
4619 "rank": (3, 3),
4620 "build_fcn": (
4621 build_fft2d,
4622 TosaTensorGen.tgFFT2d,
4623 TosaTensorValuesGen.tvgDefault,
4624 TosaArgGen.agFFT2d,
4625 ),
4626 "types": [DType.FP32],
4627 "error_if_validators": (
4628 TosaErrorValidator.evWrongInputType,
4629 TosaErrorValidator.evWrongOutputType,
4630 TosaErrorValidator.evWrongInputList,
4631 TosaErrorValidator.evWrongOutputList,
4632 TosaErrorValidator.evWrongRank,
4633 TosaErrorValidator.evBatchMismatch,
4634 TosaErrorValidator.evKernelNotPowerOfTwo,
4635 TosaErrorValidator.evFFTInputShapeMismatch,
4636 TosaErrorValidator.evFFTOutputShapeMismatch,
4637 ),
4638 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004639 "rfft2d": {
4640 "op": Op.RFFT2D,
4641 "operands": (1, 0),
4642 "rank": (3, 3),
4643 "build_fcn": (
4644 build_rfft2d,
4645 TosaTensorGen.tgRFFT2d,
4646 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004647 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004648 ),
4649 "types": [DType.FP32],
4650 "error_if_validators": (
4651 TosaErrorValidator.evWrongInputType,
4652 TosaErrorValidator.evWrongOutputType,
4653 TosaErrorValidator.evWrongInputList,
4654 TosaErrorValidator.evWrongOutputList,
4655 TosaErrorValidator.evWrongRank,
4656 TosaErrorValidator.evBatchMismatch,
4657 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004658 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004659 ),
4660 },
Won Jeon74342e52024-01-09 00:34:40 +00004661 # Shape
4662 "add_shape": {
4663 "op": Op.ADD_SHAPE,
4664 "operands": (2, 0),
4665 "build_fcn": (
4666 build_shape_op,
4667 TosaTensorGen.tgShape,
4668 TosaTensorValuesGen.tvgAddSub,
4669 TosaArgGen.agNone,
4670 ),
4671 "types": [DType.SHAPE],
4672 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4673 },
4674 "sub_shape": {
4675 "op": Op.SUB_SHAPE,
4676 "operands": (2, 0),
4677 "build_fcn": (
4678 build_shape_op,
4679 TosaTensorGen.tgShape,
4680 TosaTensorValuesGen.tvgAddSub,
4681 TosaArgGen.agNone,
4682 ),
4683 "types": [DType.SHAPE],
4684 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4685 },
4686 "mul_shape": {
4687 "op": Op.MUL_SHAPE,
4688 "operands": (2, 0),
4689 "build_fcn": (
4690 build_shape_op,
4691 TosaTensorGen.tgShape,
4692 TosaTensorValuesGen.tvgMul,
4693 TosaArgGen.agNone,
4694 ),
4695 "types": [DType.SHAPE],
4696 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4697 },
4698 "div_shape": {
4699 "op": Op.DIV_SHAPE,
4700 "operands": (2, 0),
4701 "build_fcn": (
4702 build_shape_op,
4703 TosaTensorGen.tgShape,
4704 TosaTensorValuesGen.tvgIntDiv,
4705 TosaArgGen.agNone,
4706 ),
4707 "types": [DType.SHAPE],
4708 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4709 },
4710 "concat_shape": {
4711 "op": Op.CONCAT_SHAPE,
4712 "operands": (2, 0),
4713 "build_fcn": (
4714 build_concat,
4715 TosaTensorGen.tgConcat,
4716 TosaTensorValuesGen.tvgConcat,
4717 TosaArgGen.agNone,
4718 ),
4719 "types": [DType.SHAPE],
4720 "error_if_validators": (),
4721 },
4722 "const_shape": {
4723 "op": Op.CONST_SHAPE,
4724 "operands": (0, 1),
4725 "build_fcn": (
4726 build_const,
4727 TosaTensorGen.tgBasic,
4728 TosaTensorValuesGen.tvgDefault,
4729 None,
4730 ),
4731 "types": [DType.SHAPE],
4732 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004733 }
4734
Kevin Cheng550ccc52021-03-03 11:21:43 -08004735
Eric Kunzee5e26762020-10-13 16:11:07 -07004736class OutputShaper:
4737 # Methods in this class compute the expected output shape and datatype
4738 # for common classes of operations
4739 def __init__(self):
4740 pass
4741
4742 # These methods return arguments that can be used for
4743 # creating a new output tensor
4744 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise binary op with broadcasting.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator exposing ``integers`` and ``choice``
            (presumably a ``numpy.random.Generator`` -- confirm at caller).
        a: First input tensor; must expose ``shape`` and ``dtype``.
        b: Second input tensor; must share ``a``'s dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, its rank.
        error_name: ``None`` for a valid test, or an ``ErrorIf`` value that
            selects which deliberate defect is injected into the output.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Only a valid (non-error) test broadcasts: a size-1 dim of ``a`` takes
    # ``b``'s size there.  For any error case ``a``'s dims are copied as-is
    # (so ``b`` is never indexed, which matters for RankMismatch).
    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The fuzz index is drawn on every call, not only for DimensionMismatch,
    # so RNG consumption is the same whichever error (if any) is requested.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    output_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Any dtype from this fixed list except the input's own dtype.
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        output_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004777
4778 @staticmethod
4779 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004780 assert len(a.shape) == len(b.shape)
4781 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004782
4783 shape = []
4784 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004785 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004786 shape.append(a.shape[i])
4787
Kevin Cheng550ccc52021-03-03 11:21:43 -08004788 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004789
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of an element-wise unary op.

    The output always has ``a.shape``. Its dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``. In that case a different
    dtype is drawn with ``rng.choice``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Drop the input dtype so the chosen output type is guaranteed to differ
    wrong_choices = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_choices))
Eric Kunzee5e26762020-10-13 16:11:07 -07004808
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor of SELECT.

    Args:
        ser: serializer; the output is created via ``ser.addOutput``.
        rng: random generator (``integers`` / ``choice`` are used).
        cond, a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: optional ``ErrorIf`` case being generated.

    For valid tests a size-1 dimension of ``cond`` becomes the largest of
    the matching ``cond`` / ``a`` / ``b`` dimensions; otherwise the ``cond``
    dimension is used. ``DimensionMismatch`` bumps one random output
    dimension by one, and ``WrongOutputType`` picks a dtype other than
    ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            shape.append(max(cond_dim, a.shape[i], b.shape[i]))
        else:
            shape.append(cond_dim)

    # rng.integers(0, 0) raises for rank-0 inputs, so only draw a fuzz index
    # when there is a dimension to perturb. For rank >= 1 the draw is still
    # unconditional on error_name, keeping the random stream unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004842
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor of an element-wise comparison op.

    Args:
        ser: serializer; the output is created via ``ser.addOutput``.
        rng: random generator (``integers`` / ``choice`` are used).
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: optional ``ErrorIf`` case being generated.

    A size-1 dimension of ``a`` takes ``b``'s matching dimension when ``b``
    has one; otherwise ``a``'s dimension is kept. ``DimensionMismatch``
    bumps one random output dimension by one, and ``WrongOutputType``
    replaces BOOL with a randomly chosen non-BOOL dtype.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i, dim in enumerate(a.shape):
        if dim == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    # rng.integers(0, 0) raises for rank-0 inputs, so only draw a fuzz index
    # when there is a dimension to perturb. For rank >= 1 the draw is still
    # unconditional on error_name, keeping the random stream unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so any choice is a wrong output type
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004876
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a reduction along ``axis``.

    Normally the reduced axis is collapsed to size 1. For AxisSmallerZero,
    AxisLargerRank and ShapeOfAxisNotOne the input shape is kept as-is.
    For ShapeOfAxisNotOne, an axis that is already 1 is redrawn in 2..9.
    WrongOutputType picks a dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = a.shape.copy()

    # These ERROR_IF cases must not collapse the axis: it is either out of
    # range, or it is deliberately left with a size other than one.
    keeps_axis = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if not keeps_axis:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004905
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the INT32 output tensor of ARGMAX.

    The output is ``a``'s shape with ``axis`` removed. For the axis-range
    ERROR_IF cases the axis is left in place.

    Other ERROR_IF cases:
      * ArgmaxOutputRankMismatch: randomly either drops the leading dim
        (only when more than one dim remains) or appends a trailing 1.
      * ArgmaxOutputShapeMismatch: grows every dim by 1..9.
      * WrongOutputType: any listed dtype other than INT32.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004939
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of CONV2D.

    Layouts: ifm NHWC, filter OHWI, output NHWC, with the output channel
    count taken from the filter's first dimension. padding[0:2] pads H and
    padding[2:4] pads W. strides and dilations are indexed [0] for H and
    [1] for W.

    ERROR_IF handling:
      * ConvOutputShapeMismatch grows H and/or W by 1-3 whole strides, so
        the shape is wrong without hitting the non-integer-output case.
      * WrongInputType reports INT32 as a plausible output dtype.
      * WrongOutputType picks a random usable dtype, excluding FP16/FP32
        for FP16 input and the expected dtype otherwise.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-convolution output size
        return (
            in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation
        ) // stride + 1

    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow in whole strides so only the output-shape check fails
        if which in [1, 3]:
            h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # A bad input type has no defined output, so report a plausible one
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # FP16 input: treat both FP16 and FP32 outputs as valid
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004991
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of CONV3D.

    Layouts: ifm NDHWC, filter ODHWI, output NDHWC, with the output channel
    count taken from the filter's first dimension. padding[0:2], [2:4] and
    [4:6] pad D, H and W. strides and dilations are indexed [0] for D, [1]
    for H and [2] for W.

    ERROR_IF handling:
      * ConvOutputShapeMismatch grows D, H and/or W by 1-4 whole strides, so
        the shape is wrong without hitting the non-integer-output case.
      * WrongInputType reports INT32 as a plausible output dtype.
      * WrongOutputType picks a random usable dtype, excluding FP16/FP32
        for FP16 input and the expected dtype otherwise.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-convolution output size
        return (
            in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation
        ) // stride + 1

    d = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = _out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # Grow in whole strides so only the output-shape check fails
        if which in [1, 4]:
            d += rng.choice(steps) * strides[0]
        if which in [2, 4]:
            h += rng.choice(steps) * strides[1]
        if which in [3, 4]:
            w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # A bad input type has no defined output, so report a plausible one
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # FP16 input: treat both FP16 and FP32 outputs as valid
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5053
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of DEPTHWISE_CONV2D.

    Layouts: ifm NHWC, filter HWCM, output NHW with C*M channels.
    padding[0:2] pads H and padding[2:4] pads W. strides and dilations are
    indexed [0] for H and [1] for W.

    ERROR_IF handling:
      * ConvOutputShapeMismatch grows H and/or W by 1-3 whole strides.
      * WrongInputType reports INT32 as a plausible output dtype.
      * WrongOutputType picks a random usable dtype, excluding FP16/FP32
        for FP16 input and the expected dtype otherwise.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-convolution output size
        return (
            in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation
        ) // stride + 1

    # The filter is HWCM, so the kernel sizes are filter dims 0 and 1
    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow in whole strides so only the output-shape check fails
        if which in [1, 3]:
            h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # A bad input type has no defined output, so report a plausible one
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # FP16 input: treat both FP16 and FP32 outputs as valid
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005104
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Register the output tensor of a 2D pooling op (NHWC in and out).

    H and W are ``(in + pad_a + pad_b - kernel) // stride + 1``, using
    pad[0:2] for H and pad[2:4] for W. If either stride is non-positive or
    any pad is negative, the test is invalid anyway and H and W are set to 1.
    PoolingOutputShapeMismatch grows H and/or W by 1-3 whole strides.
    WrongOutputType picks a dtype other than ``ifm.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride/pad: the ERROR_IF test fails regardless of the shape
        h = 1
        w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow in whole strides so only the output-shape check fails
        if which in [1, 3]:
            h += rng.choice(steps) * stride[0]
        if which in [2, 3]:
            w += rng.choice(steps) * stride[1]
    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005142
5143 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005144 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005145 # input: N, IC
5146 # filter: OC, IC
5147 # output: N, OC
5148
5149 output_shape = [input.shape[0], filter.shape[0]]
5150
James Ward8b390432022-08-12 20:48:56 +01005151 # Validated in arg_gen (also invalidated for ErrorIf)
5152 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005153
Kevin Cheng550ccc52021-03-03 11:21:43 -08005154 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005155
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Register the output tensor of MATMUL.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    The output dtype is ``accum_dtype``, which the argument generator
    validates, except for these cases:
      * WrongOutputType: a random dtype that is invalid for ``a.dtype``.
      * WrongInputType: INT32, a plausible output for a bad input type.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype used to leave incorrect_types unbound and
            # raise UnboundLocalError. Fall back to the rule the convolution
            # shapers use: every usable dtype except the expected one is wrong.
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005203
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Register the output tensor of CONCAT.

    The output starts as a copy of ``inputs[0].shape``. The ``axis``
    dimension of every other input is added to it, unless the ERROR_IF case
    makes that impossible (ConcatInputRankMismatch or an out-of-range axis).
    ConcatShapeSumMismatch then adds 5-9 more to that dimension.
    WrongOutputType picks a dtype other than ``inputs[0].dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    concat_impossible = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not concat_impossible:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(dtype_pool - set([first.dtype])))
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output tensor of PAD.

    Output dimension i is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF cases:
      * PadOutputShapeMismatch: one random dimension shrinks by 1 or 2.
      * RankMismatch: the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
      * WrongOutputType: a dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(a.shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        shrink_dim = rng.choice(range(len(out_shape)))
        out_shape[shrink_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005270
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Register the single-element output tensor of DIM.

    The output shape is always ``[1]`` and its dtype is ``DType.SHAPE``.
    For WrongOutputType, a random dtype from a list without SHAPE is used
    instead. ``a`` and ``axis`` are not used here.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # The set() round-trip is kept so rng.choice sees the same candidate
    # order as before.
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5291
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output tensor of RESHAPE using the requested ``shape``.

    TensorSizeInputOutputMismatch adds 1-9 to every dimension, so the
    element count no longer matches the input. WrongOutputType picks a dtype
    other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005316
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register the output tensor of SLICE, whose shape is ``size``.

    ERROR_IF cases:
      * WrongOutputType: a dtype other than ``input.dtype``.
      * SizeOutputShapeMismatch: every dimension is nudged. Dims <= 2 grow
        by 1 or 2; larger dims move by -2, -1, +1 or +2.
      * InputSizeStartLengthMismatch: the output shape copies the input's.
      * RankMismatch: the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
    ``start`` is not used here.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dims only grow, so the result stays positive
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005350
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor of TILE.

    Output dimension i is ``a.shape[i] * multiples[i]``; the two are
    asserted to have equal length. RankMismatch replaces the shape via
    ``gtu.get_rank_mismatch_shape``. WrongOutputType picks a dtype other
    than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, dim in enumerate(a.shape):
        out_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005379
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor of TRANSPOSE.

    Output dimension i is ``a.shape[perms[i]]``. When ``perms`` is
    deliberately bad (IndexOutsideBounds / IndexUsedTwice), the input shape
    is kept instead.

    Other ERROR_IF cases:
      * TensorSizeInputOutputMismatch: every dim grows by 1-9.
      * RankMismatch: the shape is replaced via
        ``gtu.get_rank_mismatch_shape``.
      * WrongOutputType: a dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    out_shape = a.shape.copy()

    assert len(perms) == len(out_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for dst in range(len(out_shape)):
            out_shape[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005412
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor of GATHER.

    The output shape is ``[values.shape[0], indices.shape[1],
    values.shape[2]]``. The asserts expect ``values`` to be rank 3 (checked
    only when the case is not WrongRank), ``indices`` to be rank 2, and both
    to share their leading dimension. The dtype is ``values.dtype`` unless
    WrongOutputType picks a different one.

    Returns:
        The result of ``ser.addOutput`` for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch, width, channels = values.shape[0], indices.shape[1], values.shape[2]

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput([batch, width, channels], out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005438
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The output takes its shape from ``values_in``; the shape object is passed
    straight to the serializer. Rank and N/W/C consistency asserts are skipped
    for ``ErrorIf.WrongRank`` tests.

    Returns:
        The serializer output tensor. Its dtype is ``values_in.dtype``, except
        for ``ErrorIf.WrongOutputType``, where a random different dtype is
        drawn from ``rng``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        incorrect_dtypes = list(set(candidate_dtypes) - {values_in.dtype})
        result_dtype = rng.choice(incorrect_dtypes)
    else:
        result_dtype = values_in.dtype

    return ser.addOutput(values_in.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005467
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output has the same shape as ``input``. Its dtype is INT32 for an
    INT16 input and INT8 otherwise. The INT8/INT16 input assert is skipped
    for ``ErrorIf.WrongInputType`` tests. For ``ErrorIf.WrongOutputType``,
    a random dtype other than the correct one is drawn from ``rng``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        # Candidates are unique, so filtering drops exactly the correct dtype.
        incorrect_dtypes = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != result_dtype
        ]
        result_dtype = rng.choice(incorrect_dtypes)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005487
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE.

    Output dims 1 and 2 (``oh``/``ow``) are computed from input dims 1 and 2 as
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``, where ``scale``
    is ``[y_n, y_d, x_n, x_d]``. Dims 0 and 3 come from the input. In error
    cases the sizes and dims are adjusted so that only the intended ERROR_IF
    condition is violated. ``mode`` and ``input_dtype`` are not used here.

    Returns:
        The serializer output tensor with dtype ``output_dtype``.
    """
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the size computation below stays well defined.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(component, 1)
            for component in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # scale/offset/border may have been altered for an ERROR_IF test;
        # keep the output tensor itself valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> adjust oh, 2 -> adjust ow, 3 -> adjust both. Steps are whole
        # multiples of the scale denominator to avoid the non-integer case.
        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += scale_y_d
        if change in (2, 3):
            if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += scale_x_d

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Dim 3 reuses input dim 0 rather than indexing input dim 3.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005566
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type conversion.

    The output keeps ``val``'s shape and uses ``out_dtype``. ``rng`` and
    ``error_name`` are accepted for a uniform shaper signature but unused.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005570
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D.

    The output uses ``output_shape`` and, normally, ``accum_dtype`` as its
    dtype. Error cases:

    * ``ConvOutputShapeMismatch``: dim 1, dim 2 or both are enlarged by 1..3.
    * ``WrongInputType``: dtype is fixed to INT32.
    * ``WrongOutputType``: a usable dtype outside the excluded set is drawn.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        change = rng.choice(steps)
        # NOTE(review): output_shape is modified in place, so the caller's
        # list also sees the enlarged dims -- confirm this is intended.
        if change in (1, 3):
            output_shape[1] += rng.choice(steps)
        if change in (2, 3):
            output_shape[2] += rng.choice(steps)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably both are
        # acceptable output types there -- confirm against the spec).
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005596
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both inputs must share a dtype. Their shapes must match (not checked for
    ``ErrorIf.FFTInputShapeMismatch``) and be rank 3 (not checked for
    ``ErrorIf.WrongRank``). Each output copies the input shape and dtype,
    which are then perturbed according to ``error_name``:

    * ``WrongOutputType``: a usable dtype other than the correct output dtype.
    * ``BatchMismatch``: dim 0 grows by 1..9.
    * ``FFTOutputShapeMismatch``: dim 1 or dim 2 grows by 1..9.

    Returns:
        A list of two output tensors with identical shape and dtype
        (presumably the real and imaginary parts -- confirm with the caller).
    """
    outputs = []

    assert ifm1.dtype == ifm2.dtype
    input_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = input_shape.copy()
    output_dtype = input_dtype

    if error_name == ErrorIf.WrongOutputType:
        # Exclude the dtype the op must actually produce (the input dtype).
        # Excluding a hard-coded FP32 could draw the correct dtype for
        # non-FP32 inputs, in which case no error would be triggered.
        wrong_dtypes = list(gtu.usableDTypes(excludes=[output_dtype]))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    outputs.append(serializer.addOutput(output_shape, output_dtype))
    outputs.append(serializer.addOutput(output_shape, output_dtype))
    return outputs
5627
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    ``value`` must be rank 3 (not checked for ``ErrorIf.WrongRank``). The
    output shape equals the input shape except the last dim, which becomes
    ``in_last // 2 + 1``. The output dtype equals ``value.dtype``. Perturbations
    by ``error_name``:

    * ``WrongOutputType``: a usable dtype other than the correct output dtype.
    * ``BatchMismatch``: dim 0 grows by 1..9.
    * ``FFTOutputShapeMismatch``: dim 1 or dim 2 grows by 1..9.

    Returns:
        A list of two output tensors with identical shape and dtype
        (presumably the real and imaginary parts -- confirm with the caller).
    """
    outputs = []

    input_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]

    output_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Exclude the dtype the op must actually produce (the input dtype),
        # not a hard-coded FP32, so the drawn dtype is always incorrect.
        wrong_dtypes = list(gtu.usableDTypes(excludes=[output_dtype]))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    outputs.append(serializer.addOutput(output_shape, output_dtype))
    outputs.append(serializer.addOutput(output_shape, output_dtype))
    return outputs
Won Jeon74342e52024-01-09 00:34:40 +00005652
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for ADD_SHAPE.

    The output copies ``a``'s shape and uses ``DType.SHAPE``. Equal-rank and
    equal-dtype asserts are skipped for ``ErrorIf.RankMismatch`` tests.
    ``ErrorIf.DimensionMismatch`` grows one randomly chosen dim by 1, and
    ``ErrorIf.WrongOutputType`` draws a wrong dtype via
    ``gtu.get_wrong_output_type``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

    result_shape = list(a.shape)

    # The index is drawn on every call, not only for DimensionMismatch, so
    # rng consumption does not depend on error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        result_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        result_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        result_dtype = DType.SHAPE
    return ser.addOutput(result_shape, result_dtype)