blob: 159ee8322dadb27ce97f25a547ce3652839f6d52 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner prepended to every generated C++ meta data file. The copyright year
# is the year in which this module was imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further 8K level limits; presumably they cap generated scale, kernel and
# stride attribute values (their uses are outside this view - confirm).
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1_000
42
def __init__(self, args):
    """Set up the generator from the parsed command line ``args``.

    Seeds ``self.rng`` from ``args.random_seed``, builds the op lists,
    creates the schema validator and data generator library, and works out
    the random floating point value range for FP32, FP16 and BF16.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        """Return ``rangeFP`` as a sorted tuple limited to +/-``maxFP``.

        The strings "max" and "-max" become ``maxFP`` and ``-maxFP``.
        Negative numbers are raised to at least ``-maxFP``, positive numbers
        are lowered to at most ``maxFP``, and anything else is kept as-is.
        """
        limited = []
        for bound in rangeFP:
            if bound == "max":
                limited.append(maxFP)
            elif bound == "-max":
                limited.append(-maxFP)
            elif bound < 0:
                # Trim to minimum data type value
                limited.append(max(bound, -maxFP))
            elif bound > 0:
                # Trim to maximum data type value
                limited.append(min(bound, maxFP))
            else:
                limited.append(bound)
        return tuple(sorted(limited))

    # Work out floating point range
    self.random_float_range = {
        fp_dtype: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype],
        )
        for fp_dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Start a new test by creating a fresh serializer.

    Sets ``self.testPath`` to ``opName/testPath`` and creates
    ``basePath/opName/testPath`` if needed.

    The serializer's constant mode is chosen as follows:
    * ``lazy_data_gen`` set: constants go to separate input files.
    * otherwise ``dump_consts`` set: constants are embedded and also dumped.
    * otherwise: constants are embedded in the flatbuffer.
    """
    self.testPath = os.path.join(opName, testPath)

    outDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        constMode = ts.ConstMode.EMBED

    self.ser = ts.TosaSerializer(outDir, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the serializer created by the latest createSerializer call.

    The value is None until createSerializer has been called.
    """
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test to ``basePath/testPath``.

    Files written:

    * ``<testName>.tosa``: the serialized TOSA flatbuffer.
    * ``<testName>_meta_data_gen.cpp``: ``metaData["data_gen"]`` as JSON
      inside a C++ raw string. Written only when that key exists and lazy
      data generation is enabled.
    * ``<testName>_meta_compliance.cpp``: ``metaData["compliance"]`` as
      JSON inside a C++ raw string. Written when that key exists.
    * ``desc.json``: the serializer's JSON descriptor, with ``metaData``
      stored under ``"meta"`` when it is given.

    The descriptor is schema-validated after the flatbuffer is written but
    before any of the other files are.

    Args:
        testName: Base name for the flatbuffer and the C++ meta data files.
        metaData: Optional dict of extra meta data for this test.
    """
    path = Path(self.basePath) / self.testPath

    def writeCppMetaData(suffix, description, varPrefix, data):
        # Emit ``data`` as JSON inside a C++ raw string literal R"(...)"
        # bound to a ``const char*`` named <varPrefix>_<test dir stem>.
        path_md = path / f"{testName}_{suffix}.cpp"
        # Explicit encoding so the output does not depend on the host locale
        with path_md.open("w", encoding="utf-8") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// {description}\n\n")
            fd.write(f'const char* {varPrefix}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output data generation meta data as CPP data
            writeCppMetaData(
                "meta_data_gen",
                "Test meta data for data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppMetaData(
                "meta_compliance",
                "Test meta data for compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w", encoding="utf-8") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a generator seeded by ``seed``.

    When ``seed`` is None, ``self.random_seed + 1`` is used instead.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value boundaries used for random ``dtype`` data.

    Float types (FP32/FP16/BF16) return the pre-computed
    ``self.random_float_range`` entry unchanged; ``high_inclusive`` does not
    apply to them.

    For every other type the high boundary is exclusive (low <= v < high).
    With ``high_inclusive`` True it is reduced by one (low <= v <= high).

    Raises:
        Exception: if ``dtype`` is not a known type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # (matching dtypes, exclusive range), checked in order
    exclusive_ranges = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # restricting too large value for SHAPE
        ((DType.INT32, DType.SHAPE), (-(1 << 31), (1 << 31))),
        ((DType.INT48,), (-(1 << 47), (1 << 47))),
    )
    for matching, (low, high) in exclusive_ranges:
        if dtype in matching:
            return (low, high - 1) if high_inclusive else (low, high)

    raise Exception("Unknown dtype: {}".format(dtype))
184
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding values for ``dtype``.

    Values are drawn from ``data_range`` (low, high) when it is given, and
    from ``getDTypeRange(dtype)`` otherwise.

    Results by type:
    * BOOL: ignores the range and picks False/True.
    * INT48: int64.
    * FP16: float16.
    * FP32: float32.
    * BF16: float32 passed through ``gtu.vect_f32_to_bf16``.
    * Any other type: int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)

    samples = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        samples = np.float32(gtu.vect_f32_to_bf16(samples))
    return samples
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data comes from ``getRandTensor`` unless lazy data generation is
    enabled, in which case None is passed as the data. Returns the values
    returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
223
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data comes from ``getRandTensor`` unless lazy data generation is
    enabled, in which case None is passed as the data. Returns the values
    returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))
    return tensors
236
def makeShape(self, rank):
    """Return an int32 array describing a tensor shape.

    If a target shape has been set (see setTargetShape) and is truthy, it is
    returned as int32 and ``rank`` is ignored. Otherwise ``rank`` dimensions
    are drawn from ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700247
def setTargetShape(self, shape):
    """Make later makeShape calls return ``shape``.

    A falsy value such as None disables the override.
    """
    self.targetted_shape = shape
250
def randInt(self, low=0, high=256):
    """Return one random ``np.int32`` drawn uniformly from ``[low, high)``."""
    draw = self.rng.integers(low=low, high=high, size=1)
    as_int32 = np.int32(draw)
    return as_int32[0]
253
def getRandNumberDType(self, dtype):
    """Return a single random value of ``dtype``.

    The value range comes from ``getDTypeRange(dtype)``.

    Results by type:
    * FP32: ``np.float32``.
    * FP16: ``np.float16``.
    * BF16: a float32 sample passed through ``gtu.vect_f32_to_bf16``.
    * BOOL: a random choice of False/True.
    * INT48 and SHAPE: ``np.int64``.
    * Any other type: ``np.int32``.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.FP32, DType.BF16):
        sample = np.float32(self.rng.uniform(low=low, high=high))
        return gtu.vect_f32_to_bf16(sample) if dtype == DType.BF16 else sample
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # Special size: INT48 and SHAPE use int64, everything else int32
    int_type = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
    return int_type(self.rng.integers(low, high, size=1))[0]
271
def shapeStr(self, shape):
    """Return ``shape`` as an "x"-separated string, e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700280
def typeStr(self, dtype):
    """Return the short name used for ``dtype`` in test names.

    A list/tuple of at least two dtypes is rendered as the names of its
    first two entries joined by "x". A third entry, if present, is the
    accumulator type and is left out, although it is still converted.

    Raises:
        Exception: if a single dtype is missing from ``gtu.DTYPE_ATTRIBUTES``.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
def typeWidth(self, dtype):
    """Return the bit width of ``dtype`` from ``gtu.DTYPE_ATTRIBUTES``.

    Raises:
        Exception: if ``dtype`` is not in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700301
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size``, modifying ``shape`` in place.

    Nothing is changed when ``max_batch_size`` is falsy or explicit target
    shapes were requested. Returns the same ``shape`` object.
    """
    limit = self.args.max_batch_size
    if limit and not self.args.target_shapes:
        shape[0] = min(shape[0], limit)
    return shape
307
def makeDimension(self):
    """Return one random dimension size in ``[tensor_shape_range[0], tensor_shape_range[1])``."""
    shape_range = self.args.tensor_shape_range
    lower, upper = shape_range[0], shape_range[1]
    return self.randInt(low=lower, high=upper)
312
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data for an operator's expected output.

    Returns None (no compliance checking) for:
    * ERROR_IF tests.
    * Output dtypes not supported by compliance.
    * MATMUL/CONV2D/FULLY_CONNECTED when ``inputType`` is not supported by
      compliance.

    Otherwise returns a dict with "mode" (a ``gtu.ComplianceMode`` name),
    "data_type", and mode-specific info. The first matching rule wins:
    * DOT_PRODUCT (``argsDict["dg_type"]`` is DOT_PRODUCT): adds
      "dot_product_info" with "s" and "ks" ("ksb" is used when present).
    * FP_SPECIAL (``dg_type`` is OP_SPECIAL): no extra info.
    * ULP (``op["compliance"]`` has "ulp"): adds "ulp_info".
    * REDUCE_PRODUCT: adds "reduce_product_info" with ``argsDict["n"]``.
    * ABS_ERROR (EXP/POW/TANH/SIGMOID): optionally adds "abs_error_info".
    * EXACT: everything else.
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        ks_key = "ksb" if "ksb" in argsDict else "ks"
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ks_key]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        has_lower_bound = (
            "compliance" in op and "abs_error_lower_bound" in op["compliance"]
        )
        if has_lower_bound:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
362
363 # Build Op functions
364 # Create the output tensor (calling OutputShaper as needed)
365 # Do final tweaks to attributes (if necessary for errorIf)
366 # Add Op into graph
367 # Return resulting tensor information or BuildInfo
368
class BuildInfo:
    """Result of a build_* operator function.

    Holds the result tensor and the compliance dict produced for it by
    tensorComplianceMetaData, which may be None.
    """

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700375
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    For NEGATE, ``qinfo[0]`` and ``qinfo[1]`` are passed to NegateAttribute.
    In a WrongOutputType ERROR_IF test whose result type is neither INT8 nor
    UINT8, ``qinfo`` is recomputed from the input and result dtypes.

    Returns:
        TosaTestGen.BuildInfo (result tensor + compliance meta data), or
        None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (a,) = inputs
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    error_if_kwargs = {
        "op": op,
        "input_dtype": a.dtype,
        "output_dtype": result_tensor.dtype,
        "qinfo": qinfo,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": pCount + cCount,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700428
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two-input operator whose output shape comes from
    OutputShaper.binaryBroadcastOp.

    ``qinfo`` is accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo (result tensor + compliance meta data), or
        None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    error_if_kwargs = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": pCount + cCount,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700470
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator whose output shape comes from
    OutputShaper.binaryNonBroadcastOp.

    No ERROR_IF validation is done: ``validator_fcns`` and ``error_name``
    are accepted but unused. Returns the result tensor.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])
    return result_tens
475
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an ARITHMETIC_RIGHT_SHIFT style operator to the graph.

    ``round`` is stored in the ArithmeticRightShiftAttribute. Returns the
    result tensor, or None when the ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
513
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL operator to the graph.

    ``args_dict["shift"]`` is stored in the MulAttribute. When the input is
    not FP16/BF16/FP32, the result type is forced to INT32. In a
    WrongOutputType ERROR_IF test the result type is then replaced by a
    random choice of INT8/INT16/INT48.

    Returns:
        TosaTestGen.BuildInfo (result tensor + compliance meta data), or
        None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tensor.setDtype(wrong_dtype)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE style operator to the graph.

    ``table`` is stored in the TableAttribute. Returns the result tensor,
    or None when the ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_kwargs = {
        "op": op,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": pCount + cCount,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return result_tens
603
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SELECT operator with inputs (cond, a, b) to the graph.

    ``qinfo`` is accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo (result tensor + compliance meta data, based
        on ``a``'s dtype), or None when the ERROR_IF validation rejects the
        test.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    result_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tensor.name]
    )

    error_if_kwargs = {
        "op": op,
        "input1": cond,
        "input2": a,
        "input3": b,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": pCount + cCount,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700651
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input comparison operator to the graph.

    The output is shaped by OutputShaper.binaryComparisonOp. ``qinfo`` is
    accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo (result tensor + compliance meta data), or
        None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    error_if_kwargs = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": result_tensor.shape,
        "output_dtype": result_tensor.dtype,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": pCount + cCount,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700699
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an argmax operator test that reduces one input along an axis.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: per-test arguments; "axis" is read, and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run. Defaults to None,
            like the other builders taking the (op, inputs, args_dict) form.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test and add it to the serializer.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: per-test arguments; "stride", "pad", "kernel" and the
            optional "acc_type" are read. The dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to PoolAttribute, or None for [0, 0].

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    # Named "ifm" rather than "input" so the builtin is not shadowed.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100817
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D convolution operator test and add it to the serializer.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: exactly three tensors, unpacked as (ifm, weights, bias).
        args_dict: per-test arguments; "acc_type", "stride", "pad" (four
            values) and "dilation" are read. The dict is forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to ConvAttribute. None is treated
            as [0, 0].

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700899
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 3D convolution operator test and add it to the serializer.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: exactly three tensors, unpacked as (ifm, weights, bias).
        args_dict: per-test arguments; "acc_type", "stride", "pad" (six
            values) and "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to ConvAttribute. None is treated
            as [0, 0].

    Returns:
        The result tensor, with no compliance metadata, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
976
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D transpose convolution operator test.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        ifm, filter, bias: input tensors. The "filter" parameter name shadows
            the builtin but is kept so keyword callers keep working.
        accum_dtype: accumulator dtype passed to the output shaper.
        stride: stride values for the attribute.
        out_pad: four output padding values.
        output_shape: requested output shape.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to TransposeConvAttribute. None is
            treated as [0, 0].

    Returns:
        The result tensor, with no compliance metadata, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1044
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a depthwise 2D convolution operator test.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: exactly three tensors, unpacked as (ifm, weights, bias).
        args_dict: per-test arguments; "acc_type", "stride", "pad" and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to ConvAttribute. None is treated
            as [0, 0].

    Returns:
        The result tensor, with no compliance metadata, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1120
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a fully-connected operator test and add it to the serializer.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: exactly three tensors, unpacked as (ifm, weights, bias).
        args_dict: per-test arguments; "acc_type" is read. The dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to FullyConnectedAttribute. None is
            treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weights" rather than "filter" so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001176
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a matrix-multiply operator test and add it to the serializer.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: exactly two input tensors (a, b).
        args_dict: per-test arguments; "acc_type" is read. The dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to MatMulAttribute. None is treated
            as [0, 0].

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points rather
    # than raising TypeError when it is indexed below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001226
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a single-axis reduction operator test.

    For a valid (error_name is None) REDUCE_PRODUCT test this stores the
    length of the reduced axis in args_dict["n"] before the compliance
    metadata is computed. Note that this mutates the caller's dict.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: per-test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    # The name lists may be deliberately invalidated for ERROR_IF tests.
    p_count, c_count = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], ins, outs, axis_attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001275
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a clamp operator test with randomly drawn bounds.

    Two random values of the input dtype are drawn. For a normal test the
    smaller one is the lower bound. For ErrorIf.MaxSmallerMin the pair is
    redrawn until the values differ, and the bounds are swapped so that
    max_val < min_val.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: per-test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def _draw_pair():
        # Left-to-right: first draw, then second draw.
        return [
            self.getRandNumberDType(src.dtype),
            self.getRandNumberDType(src.dtype),
        ]

    pair = _draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while pair[0] == pair[1]:
            pair = _draw_pair()
        min_val, max_val = max(pair), min(pair)
    else:
        min_val, max_val = min(pair), max(pair)

    # The name lists may be deliberately invalidated for ERROR_IF tests.
    p_count, c_count = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    if src.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes: the bounds go in the first pair of values and the
        # second pair is zeroed.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes: the first pair is zeroed and the bounds go in the
        # second pair.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], ins, outs, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )

    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001341
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-ReLU operator on tensor ``a`` with a random FP32 attribute.

    No ERROR_IF validation is performed here; validator_fcns is accepted but
    unused.

    Returns:
        The result tensor produced by OutputShaper.unaryOp.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    # The attribute value is a random number drawn as FP32.
    leaky_value = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(leaky_value)

    self.ser.addOperator(op["op"], [a.name], [out.name], leaky_attr)
    return out
1350
# NOTE: still needs an additional type/input; only ``a`` is wired up below.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add an attribute-less single-input operator on tensor ``a``.

    The output is shaped by OutputShaper.unaryOp. No ERROR_IF validation is
    performed; validator_fcns is accepted but unused.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1357
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that carries no attribute.

    Args:
        op: operator descriptor; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: unused.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 1
    a = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001398
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT operator joining ``inputs`` along ``args_dict["axis"]``.

    Args:
        op: operator descriptor; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: tensors to concatenate; the first one supplies the input
            shape/dtype reported to the validators and compliance metadata.
        args_dict: must contain ``"axis"``; also forwarded to compliance.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: unused.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # The WrongInputType case may supply a non-int axis on purpose.
        assert isinstance(axis, int), f"axis must be an int, got {type(axis)}"

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001451
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator for the single tensor in ``inputs``.

    ``args_dict`` supplies ``"pad"`` (flattened into the attribute, so it is
    assumed to be a numpy array — TODO confirm), ``"pad_const_int"`` and
    ``"pad_const_fp"``.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 1
    src = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # Built before validation: PadAttribute writes through self.ser.builder.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    validate_kwargs = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validate_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001509
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying ``args_dict["axis"]`` of one input tensor.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``; no compliance metadata
        is attached. Returns ``None`` when ``evValidateErrorIfs`` reports that
        the test should not be built.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001555
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator taking a data tensor and a shape tensor.

    ``inputs`` is ``[data, shape]``; the output shape is computed from
    ``args_dict["new_shape"]`` rather than from the shape tensor.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 2
    data, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, data, new_shape, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, shape_tensor.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along ``args_dict["axis"]`` for one input.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)`` (no compliance
        metadata), or ``None`` when ``evValidateErrorIfs`` reports that the
        test should not be built.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001639
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator on ``a`` using permutation ``perms``.

    Returns the output tensor, or ``None`` when ``evValidateErrorIfs``
    reports that the test should not be built.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validate_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validate_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1675
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator on ``a`` described by ``start`` and ``size``.

    Returns the output tensor, or ``None`` when ``evValidateErrorIfs``
    reports that the test should not be built.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validate_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validate_kwargs
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1714
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator taking a data tensor and a multiples tensor.

    ``inputs`` is ``[data, multiples]``; the output shape is computed from
    ``args_dict["multiples"]``.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 2
    data, multiples_tensor = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, data, multiples_values, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [data.name, multiples_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=data.shape,
        output_shape=out_tensor.shape,
        input_dtype=data.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=data,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, data.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001759
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from ``inputs = [values, indices]``.

    The values tensor supplies the input shape/dtype reported to the
    validators and to the compliance metadata.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 2
    values_tensor, index_tensor = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, index_tensor, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_tensor.name, index_tensor.name],
        [out_tensor.name],
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001802
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from ``inputs = [values_in, indices, input]``.

    ``values_in`` supplies the input shape/dtype reported to the validators
    and to the compliance metadata.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 3
    # Named ``input_tensor`` (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001844
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator on tensor ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded unchanged to
    the output shaper, the validators and ``ResizeAttribute``.
    (The parameter name ``input`` shadows the builtin but is kept for
    caller compatibility.)

    Returns the output tensor, or ``None`` when ``evValidateErrorIfs``
    reports that the test should not be built.
    """
    src = input
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    validate_kwargs = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validate_kwargs
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1906
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add an IDENTITYN operator mapping ``val``/``val2`` to two outputs.

    Only the first output tensor is returned. ``validator_fcns`` is unused.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the op code, not the whole op descriptor, matching every other
    # builder's addOperator call.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1914
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for signature
    compatibility with the other builders and are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001918
# Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator converting one input to ``args_dict["out_type"]``.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` reports that the test should not be built.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001963
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator test.

    Picks input/output zero points suited to the input/output types (or
    deliberately non-zero ones for the zero-point ERROR_IF tests), derives a
    multiplier/shift pair per channel from a random scale, optionally clamps
    the stored input values so they satisfy the scale32 shift requirement,
    and adds the RESCALE operator with its attribute to the serializer.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: Input tensor.
        out_dtype: Output DType.
        scale32: True to compute 32-bit multipliers, False for 16-bit.
        double_round: Passed through to the RESCALE attribute and validators.
        per_channel: True to use one multiplier/shift per element of the last
            input dimension, False to use a single pair for all values.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift per channel (last dimension) or one shared pair
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Input zero point; types that take a zero point also get an extra bit of
    # width (presumably to cover the zero-point subtraction - TODO confirm)
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero zero point to trigger the ERROR_IF
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    # Output zero point, chosen the same way as the input one
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero zero point to trigger the ERROR_IF
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    # where a is a random value in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int shift: under NumPy 2 (NEP 50)
        # promotion an np.int32 shift keeps the expression in int32, which
        # overflows for shifts above 32 and produces bogus clamp bounds.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        # Clamp (value - input_zp) per channel; the bound arrays broadcast
        # over the last dimension
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2127
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    The tensor normally has type BOOL and an empty (rank 0) shape. For the
    CondIfCondNotMatchingBool ERROR_IF test a wrong type is chosen instead,
    and for CondIfCondShapeNotSizeOne a shape with more than one element is
    used.

    Args:
        op: Operator definition dict (used to pick a wrong type).
        cond: Value stored in the constant.
        error_name: ERROR_IF under test, or None.

    Returns:
        The serializer const tensor holding ``[cond]``.
    """
    if error_name != ErrorIf.CondIfCondNotMatchingBool:
        cond_type = DType.BOOL
    else:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2144
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block holds one INT32 const node of that shape filled with random values
    in [0, 256). For the output-list mismatch ERROR_IF tests, the selected
    block's const is given a different shape instead.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_mismatch, else_mismatch):
        # Move large dimensions up or down; only grow small ones
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches the correctly shaped block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block: its name, the ERROR_IF that corrupts it, its normal data
    block_specs = (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    )
    for block_name, block_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            const_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            const_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(const_tens)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if passed else None
2213
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    The THEN block applies ADD (FP32/BF16/FP16/INT32) or LOGICAL_RIGHT_SHIFT
    (INT8/INT16); the ELSE block applies SUB or LOGICAL_LEFT_SHIFT
    respectively. For the input/output list mismatch ERROR_IF tests, the
    selected block is given a tensor whose shape differs from ``a``.

    Args:
        op: Operator definition dict for COND_IF.
        a: First operand, passed into both blocks.
        b: Second operand, passed into both blocks.
        cond: Value stored in the condition constant.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None when ERROR_IF validation rejects
        the test.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op pair.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Change every dimension so the shapes no longer match. Dimensions
        # <= 3 are only grown so none becomes zero or negative (one rng draw
        # per dimension either way).
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop name so the `op` parameter (the operator dict)
    # is still intact for the ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2296
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that accumulates ``a`` into a zero accumulator.

    The loop carries (iteration counter, a, acc). The COND block outputs
    ``iter > 0``. The BODY block computes ``acc + a`` and ``iter - 1``, so
    for a positive ``iter_val`` the result is ``a`` summed ``iter_val``
    times. For the WHILE_LOOP ERROR_IF tests, tensors with mismatched
    shapes, or a non-BOOL / non-size-one condition output, are substituted
    into the relevant block.

    Args:
        op: Operator definition dict for WHILE_LOOP.
        a: Tensor added to the accumulator on each iteration.
        iter_val: Initial value of the rank-0 INT32 iteration counter.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation
        rejects the test.
    """

    def mismatched_copy(tens):
        # Deep copy of tens with every dimension changed. Dimensions > 3 move
        # by -3/-2/+2/+3 and smaller ones are only grown, so no dimension
        # becomes zero or negative. A rank-0 tensor gains one dimension of
        # size > 1 so that its shape also differs.
        bad = deepcopy(tens)
        for i, dim in enumerate(bad.shape):
            bad.shape[i] = dim + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
        if not bad.shape:
            bad.shape.append(self.rng.choice([2, 3, 4]))
        return bad

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out = self.ser.addIntermediate(mismatched_copy(acc).shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = mismatched_copy(iter_tens)
        incorrect_acc = mismatched_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2417
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator test from the two input tensors val1/val2.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val1: First input tensor.
        val2: Second input tensor.
        inverse: Value of the FFT ``inverse`` attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensors from ``OutputShaper.fft2dOp``, or None when
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    param_count, const_count = op["operands"]

    # Real output shapes/dtypes; only the name lists go through the
    # ERROR_IF input/output list invalidation.
    out_names = [tens.name for tens in results]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2468
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator test for the single input tensor ``val``.

    Args:
        op: Operator definition dict; ``op["op"]`` and ``op["operands"]``
            are read.
        val: Input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensors from ``OutputShaper.rfft2dOp``, or None when
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    param_count, const_count = op["operands"]

    # Real output shapes/dtypes; only the name lists go through the
    # ERROR_IF input/output list invalidation.
    out_names = [tens.name for tens in results]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2514
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Args:
        op: Operator definition dict; ``op["rank"]`` (min, max) and
            ``op["types"]`` are read.
        shapeFilter: Requested shapes, or a falsy value for "any" ([None]).
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None. A list entry in
            ``op["types"]`` matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator consulted for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator (or an unknown
        testType).
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator allows
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapes[0] is None:
        # Bounded by the default range or by the operator, whichever is
        # the smaller range of ranks.
        op_ranks = range(rmin, rmax + 1)
        ranks = (
            default_test_rank_range
            if len(op_ranks) > len(default_test_rank_range)
            else op_ranks
        )
    else:
        ranks = range(rmin, rmax + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # Operator dtypes filtered by the requested dtypes
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # Let the validator's parameter requirements override the defaults
    reqs = validator(check=False, op=op)["param_reqs"]
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    # Reduce number of shapes to keep test numbers small
    neg_shapes = shapes[:2] if reqs["shape"] is None else reqs["shape"]

    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2595
Kevin Cheng550ccc52021-03-03 11:21:43 -08002596 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002597 self,
2598 opName,
2599 shapeFilter=[None],
2600 rankFilter=None,
2601 dtypeFilter=None,
2602 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002603 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002604
2605 try:
2606 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002607 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002608 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002609
2610 # Initialize a new random number generator
2611 self.rng = np.random.default_rng(self.random_seed)
2612
Jeremy Johnson1271c442023-09-05 11:39:26 +01002613 _, tgen_fcn, _, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002614
Eric Kunzee5e26762020-10-13 16:11:07 -07002615 # Test list consists of a tuple of:
2616 # (opName, testNameStr, dtype, shapeList, argumentsList)
2617 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002618 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002619 error_if_validators = op["error_if_validators"]
2620 else:
2621 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002622
Matthew Haddon1c00b712021-10-01 15:51:03 +01002623 for validator in error_if_validators:
2624 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002625 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002626 else:
2627 error_name = None
2628
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002629 filterDict = self.create_filter_lists(
2630 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2631 )
2632 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002633 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002634 cleanRankFilter = filterDict["rankFilter"]
2635 cleanDtypeFilter = filterDict["dtypeFilter"]
2636 cleanShapeFilter = filterDict["shapeFilter"]
2637 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002638
2639 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002640 for t in cleanDtypeFilter:
2641 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002642 # Filter out by rank
2643 if shape is not None and len(shape) != r:
2644 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002645 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002646 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002647
Matthew Haddon74567092021-07-16 15:38:20 +01002648 shapeStr = self.shapeStr(shapeList[0])
2649 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002650
Matthew Haddon74567092021-07-16 15:38:20 +01002651 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2652 argList = []
2653 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002654 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002655 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002656 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002657
Matthew Haddon74567092021-07-16 15:38:20 +01002658 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002659 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002660 if argStr:
2661 testStr = "{}_{}_{}_{}".format(
2662 opName, shapeStr, typeStr, argStr
2663 )
2664 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002665 testStr = "{}_{}_{}".format(
2666 opName, shapeStr, typeStr
2667 )
2668 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002669 if argStr:
2670 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2671 opName, error_name, shapeStr, typeStr, argStr
2672 )
2673 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002674 testStr = "{}_ERRORIF_{}_{}_{}".format(
2675 opName, error_name, shapeStr, typeStr
2676 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002677
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002678 testList.append(
2679 (opName, testStr, t, error_name, shapeList, args)
2680 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002681
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002682 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002683 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2684 if "invalid_test_validators" in op:
2685 invalid_test_validators = op["invalid_test_validators"]
2686 clean_testList = []
2687 for test in testList:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002688 remove_test = False
Matthew Haddon1c00b712021-10-01 15:51:03 +01002689 for validator_fcn in invalid_test_validators:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002690 if validator_fcn(
2691 opName=test[0],
2692 input_dtype=test[2],
2693 shapeList=test[4],
2694 args=test[5],
2695 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002696 remove_test = True
2697 if not remove_test:
2698 clean_testList.append(test)
2699 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002700
2701 return testList
2702
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Create, build and serialize a single test.

    Looks up ``opName`` in ``TOSA_OP_LIST``, creates a serializer for
    ``(opName, testStr)``, generates the input tensors with the op's
    tensor-values generator and calls the op's build function. When the
    build result is truthy the test is serialized together with any
    data-generation and compliance meta data; otherwise an
    "Invalid ERROR_IF test" message is printed and nothing is serialized.

    Args:
        opName: key of the op in ``TOSA_OP_LIST``.
        testStr: name of the test being created.
        dtype_or_dtypeList: one dtype used for every operand, or a list
            giving one dtype per operand.
        error_name: ERROR_IF error being tested, or None.
        shapeList: list of operand shapes.
        testArgs: either a dict (new interface; must contain "dg_type")
            passed to the build function as ``argsDict``, or an iterable of
            extra positional build arguments (old interface).

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
        TypeError: re-raised from an old-interface build function after
            printing the arguments it was called with.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        # The message already names the missing op; the KeyError adds nothing
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT's shape list length need not match its operand count,
        # so use one dtype per shape instead
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Only ops that define a qgen function get quantization info
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands and the test
    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        try:
            positionalArgs = [*tens, *testArgs]
            keywordArgs = {}
            if error_if_validators is None:
                # Without validators the build functions take qinfo positionally
                if qinfo is not None:
                    positionalArgs.append(qinfo)
            else:
                keywordArgs["validator_fcns"] = error_if_validators
                keywordArgs["error_name"] = error_name
                if qinfo is not None:
                    keywordArgs["qinfo"] = qinfo
            result = build_fcn(self, op, *positionalArgs, **keywordArgs)
        except TypeError:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002828
def createDynamicOpLists(self):
    """Expand the templated convolution ops into one op per kernel size.

    Each ``<prefix>_TEMPLATE`` entry of ``TOSA_OP_LIST`` (conv2d,
    depthwise_conv2d, transpose_conv2d and conv3d) is copied once per kernel
    size, e.g. ``conv2d_TEMPLATE`` -> ``conv2d_3x3``. Each copy has its
    ``filter`` set to the kernel and ``template`` set to False. Every entry
    whose ``template`` field is truthy is then removed.

    ``TOSA_OP_LIST`` is modified in place. A second call returns early,
    because ``conv2d_TEMPLATE`` has already been removed.
    """
    if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if not self.args.level8k:
        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    else:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        KERNELS_2D = [[1, bigK], [bigK, 2]]
        KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]

    # (template prefix, kernel) pairs, in insertion order: the conv2d,
    # depthwise_conv2d and transpose_conv2d variants of each 2D kernel,
    # followed by the conv3d variant of each 3D kernel
    variants = [
        (prefix, k)
        for k in KERNELS_2D
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    ]
    variants += [("conv3d", k) for k in KERNELS_3D]

    for prefix, k in variants:
        testName = "{}_{}".format(prefix, "x".join(str(dim) for dim in k))
        # A shallow copy suffices: only top-level keys are replaced below
        self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[f"{prefix}_TEMPLATE"].copy()
        self.TOSA_OP_LIST[testName]["filter"] = k
        self.TOSA_OP_LIST[testName]["template"] = False

    # Delete any templates after having created any dynamic ops; the names
    # are collected first because keys must not be deleted while iterating
    templateNames = [
        name for name, info in self.TOSA_OP_LIST.items() if info.get("template")
    ]
    for name in templateNames:
        del self.TOSA_OP_LIST[name]
2884
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Look for missing required fields (datastructure linting). Each entry of
    ``TOSA_OP_LIST`` must have:

    - an ``operands`` value that unpacks into two items;
    - a ``build_fcn`` value that unpacks into four items;
    - a ``types`` key;
    - an ``op`` key.

    An entry without ``rank`` gets ``DEFAULT_RANK_RANGE``.

    Raises:
        Exception: naming the first op (in ``TOSA_OP_LIST`` order) that
            fails one of the checks above.
    """
    for op, opInfo in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _build, _tensorGen, _tensorValuesGen, _argGen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        if "types" not in opInfo:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in opInfo:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        opInfo.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002926
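# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': list of datatypes to be tested (for TYPE_CONV each entry is
#     itself a list of [input type 1, input type 2, accumulator type])
# Optional fields: 'qgen' (TosaQuantGen function producing qinfo),
#     'error_if_validators' (validators for ERROR_IF negative tests),
#     'invalid_test_validators' (filter out invalid positive tests),
#     'data_gen', 'compliance', 'filter', and 'template' (True for
#     templates that createDynamicOpLists expands and then removes)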
# Floating-point datatypes
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer datatypes (INT4 is deliberately excluded)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# The integer datatypes followed by FP16, BF16 and FP32 (still excluding INT4)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# The floating-point datatypes plus INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer (no INT4) and boolean datatypes
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integers plus the floating-point datatypes (no INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# Convolution type combinations; each entry lists
# [input type 1, input type 2, accumulator type]
TYPE_CONV = [
    # Integer combinations
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    # Floating-point combinations
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Inclusive (min, max) rank range given to ops that have no 'rank' entry
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002978
2979 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002981 "argmax": {
2982 "op": Op.ARGMAX,
2983 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002984 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 "build_fcn": (
2986 build_argmax,
2987 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002988 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002989 TosaArgGen.agAxis,
2990 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002991 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002992 "error_if_validators": (
2993 TosaErrorValidator.evAxisSmallerZero,
2994 TosaErrorValidator.evAxisLargerRank,
2995 TosaErrorValidator.evArgmaxOutputRankMismatch,
2996 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2997 TosaErrorValidator.evWrongRank,
2998 TosaErrorValidator.evWrongInputType,
2999 TosaErrorValidator.evWrongOutputType,
3000 TosaErrorValidator.evWrongInputList,
3001 TosaErrorValidator.evWrongOutputList,
3002 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003003 "data_gen": {
3004 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3005 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003006 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003007 "avg_pool2d": {
3008 "op": Op.AVG_POOL2D,
3009 "operands": (1, 0),
3010 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003011 "build_fcn": (
3012 build_pool2d,
3013 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003014 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003015 TosaArgGen.agPooling,
3016 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003017 "qgen": TosaQuantGen.qgUnary,
3018 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003019 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003020 "error_if_validators": (
3021 TosaErrorValidator.evKernelSmallerOne,
3022 TosaErrorValidator.evStrideSmallerOne,
3023 TosaErrorValidator.evPadSmallerZero,
3024 TosaErrorValidator.evWrongRank,
3025 TosaErrorValidator.evWrongInputType,
3026 TosaErrorValidator.evWrongOutputType,
3027 TosaErrorValidator.evWrongInputList,
3028 TosaErrorValidator.evWrongOutputList,
3029 TosaErrorValidator.evInputZeroPointNotZero,
3030 TosaErrorValidator.evOutputZeroPointNotZero,
3031 TosaErrorValidator.evPadLargerEqualKernel,
3032 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003033 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003034 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003035 "data_gen": {
3036 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3037 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003038 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003039 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003040 "conv2d_TEMPLATE": {
3041 "op": Op.CONV2D,
3042 "operands": (1, 2),
3043 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 "build_fcn": (
3045 build_conv2d,
3046 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003047 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003048 TosaArgGen.agConv,
3049 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003051 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003052 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3053 "error_if_validators": (
3054 TosaErrorValidator.evWrongInputType,
3055 TosaErrorValidator.evWrongOutputType,
3056 TosaErrorValidator.evWrongInputList,
3057 TosaErrorValidator.evWrongOutputList,
3058 TosaErrorValidator.evInputZeroPointNotZero,
3059 TosaErrorValidator.evWeightZeroPointNotZero,
3060 TosaErrorValidator.evPadSmallerZero,
3061 TosaErrorValidator.evStrideSmallerOne,
3062 TosaErrorValidator.evDilationSmallerOne,
3063 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003064 TosaErrorValidator.evConvOutputShapeMismatch,
3065 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003066 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003067 "data_gen": {
3068 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3069 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003070 "template": True,
3071 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003072 # Templated operator. Filled in by createDynamicOpLists
3073 "conv3d_TEMPLATE": {
3074 "op": Op.CONV3D,
3075 "operands": (1, 2),
3076 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 "build_fcn": (
3078 build_conv3d,
3079 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003080 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003081 TosaArgGen.agConv,
3082 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003083 "qgen": TosaQuantGen.qgConv,
3084 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003085 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3086 "error_if_validators": (
3087 TosaErrorValidator.evWrongInputType,
3088 TosaErrorValidator.evWrongOutputType,
3089 TosaErrorValidator.evWrongInputList,
3090 TosaErrorValidator.evWrongOutputList,
3091 TosaErrorValidator.evInputZeroPointNotZero,
3092 TosaErrorValidator.evWeightZeroPointNotZero,
3093 TosaErrorValidator.evPadSmallerZero,
3094 TosaErrorValidator.evStrideSmallerOne,
3095 TosaErrorValidator.evDilationSmallerOne,
3096 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003097 TosaErrorValidator.evConvOutputShapeMismatch,
3098 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003099 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003100 "template": True,
3101 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003102 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003103 "depthwise_conv2d_TEMPLATE": {
3104 "op": Op.DEPTHWISE_CONV2D,
3105 "operands": (1, 2),
3106 "filter": [1, 1],
3107 "rank": (4, 4),
3108 "build_fcn": (
3109 build_depthwise_conv2d,
3110 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003111 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003112 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003113 ),
3114 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003115 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003116 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3117 "error_if_validators": (
3118 TosaErrorValidator.evWrongInputType,
3119 TosaErrorValidator.evWrongOutputType,
3120 TosaErrorValidator.evWrongInputList,
3121 TosaErrorValidator.evWrongOutputList,
3122 TosaErrorValidator.evInputZeroPointNotZero,
3123 TosaErrorValidator.evWeightZeroPointNotZero,
3124 TosaErrorValidator.evPadSmallerZero,
3125 TosaErrorValidator.evStrideSmallerOne,
3126 TosaErrorValidator.evDilationSmallerOne,
3127 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003128 TosaErrorValidator.evConvOutputShapeMismatch,
3129 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003130 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003131 "template": True,
3132 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003133 "fully_connected": {
3134 "op": Op.FULLY_CONNECTED,
3135 "operands": (1, 2),
3136 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 "build_fcn": (
3138 build_fully_connected,
3139 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003140 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003141 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003144 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003145 "error_if_validators": (
3146 TosaErrorValidator.evInputZeroPointNotZero,
3147 TosaErrorValidator.evWeightZeroPointNotZero,
3148 TosaErrorValidator.evWrongRank,
3149 TosaErrorValidator.evWrongInputType,
3150 TosaErrorValidator.evWrongOutputType,
3151 TosaErrorValidator.evWrongInputList,
3152 TosaErrorValidator.evWrongOutputList,
3153 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003154 "data_gen": {
3155 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 "matmul": {
3159 "op": Op.MATMUL,
3160 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003161 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 "build_fcn": (
3163 build_matmul,
3164 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003165 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003166 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "qgen": TosaQuantGen.qgMatmul,
3169 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003170 "error_if_validators": (
3171 TosaErrorValidator.evInputZeroPointNotZero,
3172 TosaErrorValidator.evWrongRank,
3173 TosaErrorValidator.evWrongInputType,
3174 TosaErrorValidator.evWrongOutputType,
3175 TosaErrorValidator.evWrongInputList,
3176 TosaErrorValidator.evWrongOutputList,
3177 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003178 "data_gen": {
3179 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003181 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003182 "max_pool2d": {
3183 "op": Op.MAX_POOL2D,
3184 "operands": (1, 0),
3185 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003187 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003189 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 TosaArgGen.agPooling,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003193 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003194 "error_if_validators": (
3195 TosaErrorValidator.evKernelSmallerOne,
3196 TosaErrorValidator.evStrideSmallerOne,
3197 TosaErrorValidator.evPadSmallerZero,
3198 TosaErrorValidator.evWrongRank,
3199 TosaErrorValidator.evWrongInputType,
3200 TosaErrorValidator.evWrongOutputType,
3201 TosaErrorValidator.evWrongInputList,
3202 TosaErrorValidator.evWrongOutputList,
3203 TosaErrorValidator.evPadLargerEqualKernel,
3204 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003205 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003206 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003207 "data_gen": {
3208 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003211 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003212 "transpose_conv2d_TEMPLATE": {
3213 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003214 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003215 "rank": (4, 4),
3216 "build_fcn": (
3217 build_transpose_conv2d,
3218 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003219 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003220 TosaArgGen.agTransposeConv2D,
3221 ),
3222 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003223 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003224 "invalid_test_validators": (
3225 TosaInvalidValidator.ivHeightWidthInvalid,
3226 TosaInvalidValidator.ivNonPositiveOutputShape,
3227 ),
3228 "error_if_validators": (
3229 TosaErrorValidator.evWrongInputType,
3230 TosaErrorValidator.evWrongOutputType,
3231 TosaErrorValidator.evWrongInputList,
3232 TosaErrorValidator.evWrongOutputList,
3233 TosaErrorValidator.evInputZeroPointNotZero,
3234 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003235 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003236 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003237 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003238 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003239 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003240 "template": True,
3241 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003242 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003243 "clamp": {
3244 "op": Op.CLAMP,
3245 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 "build_fcn": (
3247 build_clamp,
3248 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003249 TosaTensorValuesGen.tvgLazyGenDefault,
3250 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003252 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003253 "error_if_validators": (
3254 TosaErrorValidator.evMaxSmallerMin,
3255 TosaErrorValidator.evWrongInputType,
3256 TosaErrorValidator.evWrongOutputType,
3257 TosaErrorValidator.evWrongInputList,
3258 TosaErrorValidator.evWrongOutputList,
3259 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003260 "data_gen": {
3261 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3262 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003263 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003264 "sigmoid": {
3265 "op": Op.SIGMOID,
3266 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003268 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003270 TosaTensorValuesGen.tvgLazyGenDefault,
3271 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003273 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003274 "error_if_validators": (
3275 TosaErrorValidator.evWrongInputType,
3276 TosaErrorValidator.evWrongOutputType,
3277 TosaErrorValidator.evWrongInputList,
3278 TosaErrorValidator.evWrongOutputList,
3279 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003280 "data_gen": {
3281 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3282 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003283 },
3284 "tanh": {
3285 "op": Op.TANH,
3286 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003287 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003288 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003290 TosaTensorValuesGen.tvgLazyGenDefault,
3291 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003292 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003293 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003294 "error_if_validators": (
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003300 "data_gen": {
3301 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3302 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003303 "compliance": {
3304 "abs_error_lower_bound": 0.5,
3305 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003306 },
Won Jeon78155c62023-06-10 00:20:04 +00003307 "erf": {
3308 "op": Op.ERF,
3309 "operands": (1, 0),
3310 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003311 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003312 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003313 TosaTensorValuesGen.tvgLazyGenDefault,
3314 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003315 ),
3316 "types": TYPE_FP,
3317 "error_if_validators": (
3318 TosaErrorValidator.evWrongInputType,
3319 TosaErrorValidator.evWrongOutputType,
3320 TosaErrorValidator.evWrongInputList,
3321 TosaErrorValidator.evWrongOutputList,
3322 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003323 "data_gen": {
3324 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3325 },
3326 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 # Elementwise Binary Operators
3329 "add": {
3330 "op": Op.ADD,
3331 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 "build_fcn": (
3333 build_binary_broadcast,
3334 TosaTensorGen.tgBroadcastFuzz,
3335 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003336 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003339 "error_if_validators": (
3340 TosaErrorValidator.evRankMismatch,
3341 TosaErrorValidator.evWrongInputType,
3342 TosaErrorValidator.evWrongOutputType,
3343 TosaErrorValidator.evWrongInputList,
3344 TosaErrorValidator.evWrongOutputList,
3345 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003346 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003348 "data_gen": {
3349 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3350 },
3351 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "arithmetic_right_shift": {
3354 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3355 "operands": (2, 0),
3356 "build_fcn": (
3357 build_arithmetic_right_shift,
3358 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 TosaArgGen.agArithmeticRightShift,
3361 ),
3362 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003363 "error_if_validators": (
3364 TosaErrorValidator.evRankMismatch,
3365 TosaErrorValidator.evWrongInputType,
3366 TosaErrorValidator.evWrongOutputType,
3367 TosaErrorValidator.evWrongInputList,
3368 TosaErrorValidator.evWrongOutputList,
3369 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003370 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003371 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 "bitwise_and": {
3374 "op": Op.BITWISE_AND,
3375 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003376 "build_fcn": (
3377 build_binary_broadcast,
3378 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003379 TosaTensorValuesGen.tvgLazyGenDefault,
3380 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003383 "error_if_validators": (
3384 TosaErrorValidator.evRankMismatch,
3385 TosaErrorValidator.evWrongInputType,
3386 TosaErrorValidator.evWrongOutputType,
3387 TosaErrorValidator.evWrongInputList,
3388 TosaErrorValidator.evWrongOutputList,
3389 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003390 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 "bitwise_or": {
3394 "op": Op.BITWISE_OR,
3395 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003396 "build_fcn": (
3397 build_binary_broadcast,
3398 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003399 TosaTensorValuesGen.tvgLazyGenDefault,
3400 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003403 "error_if_validators": (
3404 TosaErrorValidator.evRankMismatch,
3405 TosaErrorValidator.evWrongInputType,
3406 TosaErrorValidator.evWrongOutputType,
3407 TosaErrorValidator.evWrongInputList,
3408 TosaErrorValidator.evWrongOutputList,
3409 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003410 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "bitwise_xor": {
3414 "op": Op.BITWISE_XOR,
3415 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003416 "build_fcn": (
3417 build_binary_broadcast,
3418 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003419 TosaTensorValuesGen.tvgLazyGenDefault,
3420 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 "error_if_validators": (
3424 TosaErrorValidator.evRankMismatch,
3425 TosaErrorValidator.evWrongInputType,
3426 TosaErrorValidator.evWrongOutputType,
3427 TosaErrorValidator.evWrongInputList,
3428 TosaErrorValidator.evWrongOutputList,
3429 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003430 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003433 "intdiv": {
3434 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003435 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003436 "build_fcn": (
3437 build_binary_broadcast,
3438 TosaTensorGen.tgBroadcastFuzz,
3439 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003440 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003442 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003443 "error_if_validators": (
3444 TosaErrorValidator.evRankMismatch,
3445 TosaErrorValidator.evWrongInputType,
3446 TosaErrorValidator.evWrongOutputType,
3447 TosaErrorValidator.evWrongInputList,
3448 TosaErrorValidator.evWrongOutputList,
3449 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003450 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003452 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "logical_and": {
3454 "op": Op.LOGICAL_AND,
3455 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003456 "build_fcn": (
3457 build_binary_broadcast,
3458 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003459 TosaTensorValuesGen.tvgLazyGenDefault,
3460 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003463 "error_if_validators": (
3464 TosaErrorValidator.evRankMismatch,
3465 TosaErrorValidator.evWrongInputType,
3466 TosaErrorValidator.evWrongOutputType,
3467 TosaErrorValidator.evWrongInputList,
3468 TosaErrorValidator.evWrongOutputList,
3469 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003470 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "logical_left_shift": {
3474 "op": Op.LOGICAL_LEFT_SHIFT,
3475 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003476 "build_fcn": (
3477 build_binary_broadcast,
3478 TosaTensorGen.tgBroadcastFuzz,
3479 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003480 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003481 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 "error_if_validators": (
3484 TosaErrorValidator.evRankMismatch,
3485 TosaErrorValidator.evWrongInputType,
3486 TosaErrorValidator.evWrongOutputType,
3487 TosaErrorValidator.evWrongInputList,
3488 TosaErrorValidator.evWrongOutputList,
3489 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003490 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003491 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "logical_right_shift": {
3494 "op": Op.LOGICAL_RIGHT_SHIFT,
3495 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003496 "build_fcn": (
3497 build_binary_broadcast,
3498 TosaTensorGen.tgBroadcastFuzz,
3499 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003500 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003501 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003502 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003503 "error_if_validators": (
3504 TosaErrorValidator.evRankMismatch,
3505 TosaErrorValidator.evWrongInputType,
3506 TosaErrorValidator.evWrongOutputType,
3507 TosaErrorValidator.evWrongInputList,
3508 TosaErrorValidator.evWrongOutputList,
3509 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003510 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003511 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "logical_or": {
3514 "op": Op.LOGICAL_OR,
3515 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003516 "build_fcn": (
3517 build_binary_broadcast,
3518 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003519 TosaTensorValuesGen.tvgLazyGenDefault,
3520 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003522 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003523 "error_if_validators": (
3524 TosaErrorValidator.evRankMismatch,
3525 TosaErrorValidator.evWrongInputType,
3526 TosaErrorValidator.evWrongOutputType,
3527 TosaErrorValidator.evWrongInputList,
3528 TosaErrorValidator.evWrongOutputList,
3529 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003530 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003531 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003532 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003533 "logical_xor": {
3534 "op": Op.LOGICAL_XOR,
3535 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003536 "build_fcn": (
3537 build_binary_broadcast,
3538 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003539 TosaTensorValuesGen.tvgLazyGenDefault,
3540 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003542 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003543 "error_if_validators": (
3544 TosaErrorValidator.evRankMismatch,
3545 TosaErrorValidator.evWrongInputType,
3546 TosaErrorValidator.evWrongOutputType,
3547 TosaErrorValidator.evWrongInputList,
3548 TosaErrorValidator.evWrongOutputList,
3549 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003550 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003551 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003552 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 "maximum": {
3554 "op": Op.MAXIMUM,
3555 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556 "build_fcn": (
3557 build_binary_broadcast,
3558 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003559 TosaTensorValuesGen.tvgLazyGenDefault,
3560 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003562 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003563 "error_if_validators": (
3564 TosaErrorValidator.evRankMismatch,
3565 TosaErrorValidator.evWrongInputType,
3566 TosaErrorValidator.evWrongOutputType,
3567 TosaErrorValidator.evWrongInputList,
3568 TosaErrorValidator.evWrongOutputList,
3569 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003570 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003571 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003572 "data_gen": {
3573 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3574 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003576 "minimum": {
3577 "op": Op.MINIMUM,
3578 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 "build_fcn": (
3580 build_binary_broadcast,
3581 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003582 TosaTensorValuesGen.tvgLazyGenDefault,
3583 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003586 "error_if_validators": (
3587 TosaErrorValidator.evRankMismatch,
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongInputList,
3591 TosaErrorValidator.evWrongOutputList,
3592 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003593 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003595 "data_gen": {
3596 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 "mul": {
3600 "op": Op.MUL,
3601 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003602 "build_fcn": (
3603 build_mul,
3604 TosaTensorGen.tgBroadcastFuzz,
3605 TosaTensorValuesGen.tvgMul,
3606 TosaArgGen.agMul,
3607 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003608 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003609 "error_if_validators": (
3610 TosaErrorValidator.evWrongInputType,
3611 TosaErrorValidator.evWrongOutputType,
3612 TosaErrorValidator.evWrongInputList,
3613 TosaErrorValidator.evWrongOutputList,
3614 TosaErrorValidator.evRankMismatch,
3615 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003616 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003617 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003618 "data_gen": {
3619 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3620 },
3621 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003622 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003623 "pow": {
3624 "op": Op.POW,
3625 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 "build_fcn": (
3627 build_binary_broadcast,
3628 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003629 TosaTensorValuesGen.tvgPow,
3630 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003631 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003632 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003633 "error_if_validators": (
3634 TosaErrorValidator.evRankMismatch,
3635 TosaErrorValidator.evWrongInputType,
3636 TosaErrorValidator.evWrongOutputType,
3637 TosaErrorValidator.evWrongInputList,
3638 TosaErrorValidator.evWrongOutputList,
3639 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003640 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003642 "data_gen": {
3643 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 "sub": {
3647 "op": Op.SUB,
3648 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003649 "build_fcn": (
3650 build_binary_broadcast,
3651 TosaTensorGen.tgBroadcastFuzz,
3652 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003653 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003656 "error_if_validators": (
3657 TosaErrorValidator.evRankMismatch,
3658 TosaErrorValidator.evWrongInputType,
3659 TosaErrorValidator.evWrongOutputType,
3660 TosaErrorValidator.evWrongInputList,
3661 TosaErrorValidator.evWrongOutputList,
3662 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003663 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003664 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003665 "data_gen": {
3666 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3667 },
3668 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 "table": {
3671 "op": Op.TABLE,
3672 # Use the automatic generation functions to create the input array
3673 # but create the table tensor in the build function, as it may be
3674 # a different type from the input
3675 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003676 "build_fcn": (
3677 build_table,
3678 TosaTensorGen.tgBasic,
3679 TosaTensorValuesGen.tvgDefault,
3680 TosaArgGen.agTable,
3681 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003682 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003683 "error_if_validators": (
3684 TosaErrorValidator.evWrongInputType,
3685 TosaErrorValidator.evWrongOutputType,
3686 TosaErrorValidator.evWrongInputList,
3687 TosaErrorValidator.evWrongOutputList,
3688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 # Elementwise Unary operators
3691 "abs": {
3692 "op": Op.ABS,
3693 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 "build_fcn": (
3695 build_unary,
3696 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003697 TosaTensorValuesGen.tvgLazyGenDefault,
3698 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003701 "error_if_validators": (
3702 TosaErrorValidator.evWrongInputType,
3703 TosaErrorValidator.evWrongOutputType,
3704 TosaErrorValidator.evWrongInputList,
3705 TosaErrorValidator.evWrongOutputList,
3706 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003707 "data_gen": {
3708 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "bitwise_not": {
3712 "op": Op.BITWISE_NOT,
3713 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_unary,
3716 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003717 TosaTensorValuesGen.tvgLazyGenDefault,
3718 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongInputList,
3725 TosaErrorValidator.evWrongOutputList,
3726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 "ceil": {
3729 "op": Op.CEIL,
3730 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003731 "build_fcn": (
3732 build_unary,
3733 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003734 TosaTensorValuesGen.tvgLazyGenDefault,
3735 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003736 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003738 "error_if_validators": (
3739 TosaErrorValidator.evWrongInputType,
3740 TosaErrorValidator.evWrongOutputType,
3741 TosaErrorValidator.evWrongInputList,
3742 TosaErrorValidator.evWrongOutputList,
3743 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003744 "data_gen": {
3745 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3746 },
3747 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "clz": {
3750 "op": Op.CLZ,
3751 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_unary,
3754 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003755 TosaTensorValuesGen.tvgLazyGenDefault,
3756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
3764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 "exp": {
3767 "op": Op.EXP,
3768 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 "build_fcn": (
3770 build_unary,
3771 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003772 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003773 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 "error_if_validators": (
3777 TosaErrorValidator.evWrongInputType,
3778 TosaErrorValidator.evWrongOutputType,
3779 TosaErrorValidator.evWrongInputList,
3780 TosaErrorValidator.evWrongOutputList,
3781 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003782 "data_gen": {
3783 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3784 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "floor": {
3787 "op": Op.FLOOR,
3788 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_unary,
3791 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003792 TosaTensorValuesGen.tvgLazyGenDefault,
3793 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003796 "error_if_validators": (
3797 TosaErrorValidator.evWrongInputType,
3798 TosaErrorValidator.evWrongOutputType,
3799 TosaErrorValidator.evWrongInputList,
3800 TosaErrorValidator.evWrongOutputList,
3801 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003802 "data_gen": {
3803 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3804 },
3805 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 "log": {
3808 "op": Op.LOG,
3809 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003810 "build_fcn": (
3811 build_unary,
3812 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003813 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003814 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003815 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003816 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003817 "error_if_validators": (
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003823 "data_gen": {
3824 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3825 },
3826 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "logical_not": {
3829 "op": Op.LOGICAL_NOT,
3830 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003831 "build_fcn": (
3832 build_unary,
3833 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003834 TosaTensorValuesGen.tvgLazyGenDefault,
3835 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003838 "error_if_validators": (
3839 TosaErrorValidator.evWrongInputType,
3840 TosaErrorValidator.evWrongOutputType,
3841 TosaErrorValidator.evWrongInputList,
3842 TosaErrorValidator.evWrongOutputList,
3843 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003845 "negate": {
3846 "op": Op.NEGATE,
3847 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003848 "build_fcn": (
3849 build_unary,
3850 TosaTensorGen.tgBasic,
3851 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003852 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003853 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003854 "qgen": TosaQuantGen.qgUnary,
3855 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003856 "error_if_validators": (
3857 TosaErrorValidator.evInputZeroPointNotZero,
3858 TosaErrorValidator.evOutputZeroPointNotZero,
3859 TosaErrorValidator.evWrongInputType,
3860 TosaErrorValidator.evWrongOutputType,
3861 TosaErrorValidator.evWrongInputList,
3862 TosaErrorValidator.evWrongOutputList,
3863 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003864 "data_gen": {
3865 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3866 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 "reciprocal": {
3869 "op": Op.RECIPROCAL,
3870 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003871 "build_fcn": (
3872 build_unary,
3873 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003874 TosaTensorValuesGen.tvgLazyGenDefault,
3875 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003878 "error_if_validators": (
3879 TosaErrorValidator.evWrongInputType,
3880 TosaErrorValidator.evWrongOutputType,
3881 TosaErrorValidator.evWrongInputList,
3882 TosaErrorValidator.evWrongOutputList,
3883 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003884 "data_gen": {
3885 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3886 },
3887 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003889 "rsqrt": {
3890 "op": Op.RSQRT,
3891 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 "build_fcn": (
3893 build_unary,
3894 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003895 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003896 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 "error_if_validators": (
3900 TosaErrorValidator.evWrongInputType,
3901 TosaErrorValidator.evWrongOutputType,
3902 TosaErrorValidator.evWrongInputList,
3903 TosaErrorValidator.evWrongOutputList,
3904 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003905 "data_gen": {
3906 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3907 },
3908 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 # Elementwise Ternary operators
3911 "select": {
3912 "op": Op.SELECT,
3913 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 "build_fcn": (
3915 build_select,
3916 TosaTensorGen.tgBroadcastFuzz,
3917 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003918 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 "error_if_validators": (
3922 TosaErrorValidator.evRankMismatch,
3923 TosaErrorValidator.evWrongInputType,
3924 TosaErrorValidator.evWrongOutputType,
3925 TosaErrorValidator.evWrongInputList,
3926 TosaErrorValidator.evWrongOutputList,
3927 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003928 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003929 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003930 "data_gen": {
3931 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3932 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003933 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003934 # Comparison operators
3935 "equal": {
3936 "op": Op.EQUAL,
3937 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003938 "build_fcn": (
3939 build_comparison,
3940 TosaTensorGen.tgBroadcastFuzz,
3941 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003942 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003943 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003945 "error_if_validators": (
3946 TosaErrorValidator.evRankMismatch,
3947 TosaErrorValidator.evWrongInputType,
3948 TosaErrorValidator.evWrongOutputType,
3949 TosaErrorValidator.evWrongInputList,
3950 TosaErrorValidator.evWrongOutputList,
3951 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003952 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003953 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003954 "data_gen": {
3955 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3956 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 "greater_equal": {
3959 "op": Op.GREATER_EQUAL,
3960 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003961 "build_fcn": (
3962 build_comparison,
3963 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003964 TosaTensorValuesGen.tvgLazyGenDefault,
3965 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003968 "error_if_validators": (
3969 TosaErrorValidator.evRankMismatch,
3970 TosaErrorValidator.evWrongInputType,
3971 TosaErrorValidator.evWrongOutputType,
3972 TosaErrorValidator.evWrongInputList,
3973 TosaErrorValidator.evWrongOutputList,
3974 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003975 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003976 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003977 "data_gen": {
3978 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3979 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003980 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003981 "greater": {
3982 "op": Op.GREATER,
3983 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003984 "build_fcn": (
3985 build_comparison,
3986 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003987 TosaTensorValuesGen.tvgLazyGenDefault,
3988 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003989 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003991 "error_if_validators": (
3992 TosaErrorValidator.evRankMismatch,
3993 TosaErrorValidator.evWrongInputType,
3994 TosaErrorValidator.evWrongOutputType,
3995 TosaErrorValidator.evWrongInputList,
3996 TosaErrorValidator.evWrongOutputList,
3997 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003998 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003999 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004000 "data_gen": {
4001 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4002 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 # Reduction operators
4005 "reduce_all": {
4006 "op": Op.REDUCE_ALL,
4007 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004008 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 "build_fcn": (
4010 build_reduce,
4011 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004012 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004013 TosaArgGen.agAxis,
4014 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004016 "error_if_validators": (
4017 TosaErrorValidator.evAxisLargerRank,
4018 TosaErrorValidator.evAxisSmallerZero,
4019 TosaErrorValidator.evShapeOfAxisNotOne,
4020 TosaErrorValidator.evWrongInputType,
4021 TosaErrorValidator.evWrongOutputType,
4022 TosaErrorValidator.evWrongRank,
4023 TosaErrorValidator.evWrongInputList,
4024 TosaErrorValidator.evWrongOutputList,
4025 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004026 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004027 "reduce_any": {
4028 "op": Op.REDUCE_ANY,
4029 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004030 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004031 "build_fcn": (
4032 build_reduce,
4033 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004034 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004035 TosaArgGen.agAxis,
4036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004038 "error_if_validators": (
4039 TosaErrorValidator.evAxisLargerRank,
4040 TosaErrorValidator.evAxisSmallerZero,
4041 TosaErrorValidator.evShapeOfAxisNotOne,
4042 TosaErrorValidator.evWrongInputType,
4043 TosaErrorValidator.evWrongOutputType,
4044 TosaErrorValidator.evWrongRank,
4045 TosaErrorValidator.evWrongInputList,
4046 TosaErrorValidator.evWrongOutputList,
4047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004049 "reduce_max": {
4050 "op": Op.REDUCE_MAX,
4051 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004052 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004053 "build_fcn": (
4054 build_reduce,
4055 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004056 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004057 TosaArgGen.agAxis,
4058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004060 "error_if_validators": (
4061 TosaErrorValidator.evAxisLargerRank,
4062 TosaErrorValidator.evAxisSmallerZero,
4063 TosaErrorValidator.evShapeOfAxisNotOne,
4064 TosaErrorValidator.evWrongInputType,
4065 TosaErrorValidator.evWrongOutputType,
4066 TosaErrorValidator.evWrongRank,
4067 TosaErrorValidator.evWrongInputList,
4068 TosaErrorValidator.evWrongOutputList,
4069 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004070 "data_gen": {
4071 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004073 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004074 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004075 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004076 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004077 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004078 "build_fcn": (
4079 build_reduce,
4080 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004081 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 TosaArgGen.agAxis,
4083 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004084 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004085 "error_if_validators": (
4086 TosaErrorValidator.evAxisLargerRank,
4087 TosaErrorValidator.evAxisSmallerZero,
4088 TosaErrorValidator.evShapeOfAxisNotOne,
4089 TosaErrorValidator.evWrongInputType,
4090 TosaErrorValidator.evWrongOutputType,
4091 TosaErrorValidator.evWrongRank,
4092 TosaErrorValidator.evWrongInputList,
4093 TosaErrorValidator.evWrongOutputList,
4094 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004095 "data_gen": {
4096 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4097 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004098 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004099 "reduce_product": {
4100 "op": Op.REDUCE_PRODUCT,
4101 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004102 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004103 "build_fcn": (
4104 build_reduce,
4105 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004106 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 TosaArgGen.agAxis,
4108 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004109 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004110 "error_if_validators": (
4111 TosaErrorValidator.evAxisLargerRank,
4112 TosaErrorValidator.evAxisSmallerZero,
4113 TosaErrorValidator.evShapeOfAxisNotOne,
4114 TosaErrorValidator.evWrongInputType,
4115 TosaErrorValidator.evWrongOutputType,
4116 TosaErrorValidator.evWrongRank,
4117 TosaErrorValidator.evWrongInputList,
4118 TosaErrorValidator.evWrongOutputList,
4119 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004120 "data_gen": {
4121 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004123 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004124 "reduce_sum": {
4125 "op": Op.REDUCE_SUM,
4126 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004127 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 "build_fcn": (
4129 build_reduce,
4130 TosaTensorGen.tgBasic,
4131 TosaTensorValuesGen.tvgReduceSum,
4132 TosaArgGen.agAxis,
4133 ),
James Ward24dbc422022-10-19 12:20:31 +01004134 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004135 "error_if_validators": (
4136 TosaErrorValidator.evAxisLargerRank,
4137 TosaErrorValidator.evAxisSmallerZero,
4138 TosaErrorValidator.evShapeOfAxisNotOne,
4139 TosaErrorValidator.evWrongInputType,
4140 TosaErrorValidator.evWrongOutputType,
4141 TosaErrorValidator.evWrongRank,
4142 TosaErrorValidator.evWrongInputList,
4143 TosaErrorValidator.evWrongOutputList,
4144 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004145 "data_gen": {
4146 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4147 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004148 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004149 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004150 "concat": {
4151 "op": Op.CONCAT,
4152 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004153 "build_fcn": (
4154 build_concat,
4155 TosaTensorGen.tgConcat,
4156 TosaTensorValuesGen.tvgConcat,
4157 TosaArgGen.agAxis,
4158 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004159 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004160 "error_if_validators": (
4161 TosaErrorValidator.evAxisLargerRank,
4162 TosaErrorValidator.evAxisSmallerZero,
4163 TosaErrorValidator.evConcatInputRankMismatch,
4164 TosaErrorValidator.evConcatShapeSumMismatch,
4165 TosaErrorValidator.evConcatInputDimMismatch,
4166 TosaErrorValidator.evWrongInputType,
4167 TosaErrorValidator.evWrongOutputType,
4168 TosaErrorValidator.evWrongOutputList,
4169 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004170 "data_gen": {
4171 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4172 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004173 },
4174 "pad": {
4175 "op": Op.PAD,
4176 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004177 "build_fcn": (
4178 build_pad,
4179 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004180 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 TosaArgGen.agPad,
4182 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004183 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004184 "error_if_validators": (
4185 TosaErrorValidator.evWrongInputType,
4186 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004187 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 TosaErrorValidator.evWrongOutputType,
4189 TosaErrorValidator.evWrongInputList,
4190 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004191 TosaErrorValidator.evRankMismatch,
4192 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004193 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004194 "data_gen": {
4195 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4196 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004197 },
Won Jeona21b2e82023-08-10 10:33:01 +00004198 "dim": {
4199 "op": Op.DIM,
4200 "operands": (1, 0),
4201 "build_fcn": (
4202 build_dim,
4203 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004204 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004205 TosaArgGen.agAxis,
4206 ),
4207 "types": TYPE_FIB,
4208 "error_if_validators": (
4209 TosaErrorValidator.evAxisLargerRank,
4210 TosaErrorValidator.evAxisSmallerZero,
4211 TosaErrorValidator.evWrongInputType,
4212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
4214 TosaErrorValidator.evWrongRank,
4215 ),
4216 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004217 "reshape": {
4218 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004219 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 "build_fcn": (
4221 build_reshape,
4222 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004223 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004224 TosaArgGen.agReshape,
4225 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004226 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004227 "error_if_validators": (
4228 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4229 TosaErrorValidator.evWrongInputType,
4230 TosaErrorValidator.evWrongOutputType,
4231 TosaErrorValidator.evWrongInputList,
4232 TosaErrorValidator.evWrongOutputList,
4233 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004234 "data_gen": {
4235 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4236 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004237 },
4238 "reverse": {
4239 "op": Op.REVERSE,
4240 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 "build_fcn": (
4242 build_reverse,
4243 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004244 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004245 TosaArgGen.agAxis,
4246 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004247 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evAxisSmallerZero,
4250 TosaErrorValidator.evAxisLargerRank,
4251 TosaErrorValidator.evWrongInputType,
4252 TosaErrorValidator.evWrongOutputType,
4253 TosaErrorValidator.evWrongInputList,
4254 TosaErrorValidator.evWrongOutputList,
4255 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004256 },
4257 "slice": {
4258 "op": Op.SLICE,
4259 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004260 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004261 "build_fcn": (
4262 build_slice,
4263 TosaTensorGen.tgBasic,
4264 TosaTensorValuesGen.tvgDefault,
4265 TosaArgGen.agSlice,
4266 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004267 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004268 "error_if_validators": (
4269 TosaErrorValidator.evStartSmallerZero,
4270 TosaErrorValidator.evSizeSmallerEqualZero,
4271 TosaErrorValidator.evStartSizeOutsideBounds,
4272 TosaErrorValidator.evSizeOutputShapeMismatch,
4273 TosaErrorValidator.evInputSizeStartLengthMismatch,
4274 TosaErrorValidator.evWrongRank,
4275 TosaErrorValidator.evWrongInputType,
4276 TosaErrorValidator.evWrongOutputType,
4277 TosaErrorValidator.evWrongInputList,
4278 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004279 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004280 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004281 },
4282 "tile": {
4283 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004284 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004285 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004286 "build_fcn": (
4287 build_tile,
4288 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004289 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004290 TosaArgGen.agTile,
4291 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004292 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004293 "error_if_validators": (
4294 TosaErrorValidator.evWrongInputType,
4295 TosaErrorValidator.evWrongOutputType,
4296 TosaErrorValidator.evWrongInputList,
4297 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004298 TosaErrorValidator.evRankMismatch,
4299 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004300 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004301 "data_gen": {
4302 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4303 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004304 },
4305 "transpose": {
4306 "op": Op.TRANSPOSE,
4307 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004308 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 "build_fcn": (
4310 build_transpose,
4311 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004312 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004313 TosaArgGen.agTranspose,
4314 ),
4315 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004316 "error_if_validators": (
4317 TosaErrorValidator.evIndexOutsideBounds,
4318 TosaErrorValidator.evIndexUsedTwice,
4319 TosaErrorValidator.evWrongInputType,
4320 TosaErrorValidator.evWrongOutputType,
4321 TosaErrorValidator.evWrongInputList,
4322 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004323 TosaErrorValidator.evWrongRank,
4324 TosaErrorValidator.evRankMismatch,
4325 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004326 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004328 # Data nodes
4329 "const": {
4330 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004331 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004332 "build_fcn": (
4333 build_const,
4334 TosaTensorGen.tgBasic,
4335 TosaTensorValuesGen.tvgDefault,
4336 None,
4337 ),
Luke Hutton65872422023-02-20 10:33:04 +00004338 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004339 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004340 "identity": {
4341 "op": Op.IDENTITY,
4342 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004343 "build_fcn": (
4344 build_unary,
4345 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004346 TosaTensorValuesGen.tvgLazyGenDefault,
4347 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004348 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004349 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004350 "data_gen": {
4351 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004354 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004355 "gather": {
4356 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004357 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004359 "build_fcn": (
4360 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004361 TosaTensorGen.tgGather,
4362 TosaTensorValuesGen.tvgGather,
4363 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004364 ),
James Ward24dbc422022-10-19 12:20:31 +01004365 "types": (
4366 DType.INT8,
4367 DType.INT16,
4368 DType.INT32,
4369 DType.FP16,
4370 DType.BF16,
4371 DType.FP32,
4372 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004373 "error_if_validators": (
4374 TosaErrorValidator.evWrongInputType,
4375 TosaErrorValidator.evWrongOutputType,
4376 TosaErrorValidator.evWrongInputList,
4377 TosaErrorValidator.evWrongOutputList,
4378 TosaErrorValidator.evWrongRank,
4379 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004380 "data_gen": {
4381 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4382 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004383 },
4384 "scatter": {
4385 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004386 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004387 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004388 "build_fcn": (
4389 build_scatter,
4390 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004391 TosaTensorValuesGen.tvgScatter,
4392 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004393 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004394 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004395 "error_if_validators": (
4396 TosaErrorValidator.evWrongInputType,
4397 TosaErrorValidator.evWrongOutputType,
4398 TosaErrorValidator.evWrongInputList,
4399 TosaErrorValidator.evWrongOutputList,
4400 TosaErrorValidator.evWrongRank,
4401 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004402 "data_gen": {
4403 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4404 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004405 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004406 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004407 "resize": {
4408 "op": Op.RESIZE,
4409 "operands": (1, 0),
4410 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004411 "build_fcn": (
4412 build_resize,
4413 TosaTensorGen.tgNHWC,
4414 TosaTensorValuesGen.tvgDefault,
4415 TosaArgGen.agResize,
4416 ),
James Ward24dbc422022-10-19 12:20:31 +01004417 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004418 "invalid_test_validators": (
4419 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004420 ),
4421 "error_if_validators": (
4422 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004423 TosaErrorValidator.evScaleSmallerEqualZero,
4424 TosaErrorValidator.evScaleNLargerMax,
4425 TosaErrorValidator.evScaleDLargerMax,
4426 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004427 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004428 TosaErrorValidator.evBorderSmallerMin,
4429 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004430 TosaErrorValidator.evWrongInputType,
4431 TosaErrorValidator.evWrongOutputType,
4432 TosaErrorValidator.evWrongRank,
4433 TosaErrorValidator.evWrongInputList,
4434 TosaErrorValidator.evWrongOutputList,
4435 TosaErrorValidator.evBatchMismatch,
4436 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004437 TosaErrorValidator.evResizeOutputShapeMismatch,
4438 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004439 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004441 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004442 "cast": {
4443 "op": Op.CAST,
4444 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004445 "build_fcn": (
4446 build_cast,
4447 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004448 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004449 TosaArgGen.agCast,
4450 ),
James Ward8b390432022-08-12 20:48:56 +01004451 "types": (
4452 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004453 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004454 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004455 DType.INT8,
4456 DType.INT16,
4457 DType.INT32,
4458 DType.BOOL,
4459 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004460 "error_if_validators": (
4461 TosaErrorValidator.evWrongInputType,
4462 TosaErrorValidator.evWrongOutputType,
4463 TosaErrorValidator.evWrongInputList,
4464 TosaErrorValidator.evWrongOutputList,
4465 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004466 "data_gen": {
4467 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4468 },
4469 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004470 },
4471 "rescale": {
4472 "op": Op.RESCALE,
4473 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004474 "build_fcn": (
4475 build_rescale,
4476 TosaTensorGen.tgBasic,
4477 TosaTensorValuesGen.tvgDefault,
4478 TosaArgGen.agRescale,
4479 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004480 "types": [
4481 DType.UINT8,
4482 DType.INT8,
4483 DType.INT16,
4484 DType.INT32,
4485 DType.INT48,
4486 DType.UINT16,
4487 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 "error_if_validators": (
4489 TosaErrorValidator.evInputZeroPointNotZero,
4490 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004491 TosaErrorValidator.evU16InputZeroPointNotValid,
4492 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004493 TosaErrorValidator.evScaleTrue,
4494 TosaErrorValidator.evScaleNotTrue,
4495 TosaErrorValidator.evWrongInputType,
4496 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004497 TosaErrorValidator.evWrongInputList,
4498 TosaErrorValidator.evWrongOutputList,
4499 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004500 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004501 # Custom
4502 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004503 # Control flow operators
        # Two variants of cond_if: one generates one of two constant tensors (no
        # inputs to the basic blocks, one output); the other either adds or
        # subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004507 "cond_if_const": {
4508 "op": Op.COND_IF,
4509 "operands": (0, 2),
4510 "build_fcn": (
4511 build_cond_if_const,
4512 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004513 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004514 TosaArgGen.agCondIf,
4515 ),
4516 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004517 "error_if_validators": (
4518 TosaErrorValidator.evOutputListThenGraphMismatch,
4519 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004520 TosaErrorValidator.evCondIfCondNotMatchingBool,
4521 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004522 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004523 },
4524 "cond_if_binary": {
4525 "op": Op.COND_IF,
4526 "operands": (2, 0),
4527 "build_fcn": (
4528 build_cond_if_binary,
4529 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004530 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004531 TosaArgGen.agCondIf,
4532 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004533 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004534 "error_if_validators": (
4535 TosaErrorValidator.evInputListThenGraphMismatch,
4536 TosaErrorValidator.evInputListElseGraphMismatch,
4537 TosaErrorValidator.evOutputListThenGraphMismatch,
4538 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004539 TosaErrorValidator.evCondIfCondNotMatchingBool,
4540 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004542 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004543 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004544 "while_loop": {
4545 "op": Op.WHILE_LOOP,
4546 "operands": (0, 1),
4547 "build_fcn": (
4548 build_while_loop,
4549 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004550 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004551 TosaArgGen.agWhileLoop,
4552 ),
4553 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004554 "error_if_validators": (
4555 TosaErrorValidator.evInputListOutputListMismatch,
4556 TosaErrorValidator.evInputListCondGraphMismatch,
4557 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4558 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4559 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004560 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004561 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004562 },
Luke Hutton57287132023-02-06 14:54:18 +00004563 "fft2d": {
4564 "op": Op.FFT2D,
4565 "operands": (2, 0),
4566 "rank": (3, 3),
4567 "build_fcn": (
4568 build_fft2d,
4569 TosaTensorGen.tgFFT2d,
4570 TosaTensorValuesGen.tvgDefault,
4571 TosaArgGen.agFFT2d,
4572 ),
4573 "types": [DType.FP32],
4574 "error_if_validators": (
4575 TosaErrorValidator.evWrongInputType,
4576 TosaErrorValidator.evWrongOutputType,
4577 TosaErrorValidator.evWrongInputList,
4578 TosaErrorValidator.evWrongOutputList,
4579 TosaErrorValidator.evWrongRank,
4580 TosaErrorValidator.evBatchMismatch,
4581 TosaErrorValidator.evKernelNotPowerOfTwo,
4582 TosaErrorValidator.evFFTInputShapeMismatch,
4583 TosaErrorValidator.evFFTOutputShapeMismatch,
4584 ),
4585 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004586 "rfft2d": {
4587 "op": Op.RFFT2D,
4588 "operands": (1, 0),
4589 "rank": (3, 3),
4590 "build_fcn": (
4591 build_rfft2d,
4592 TosaTensorGen.tgRFFT2d,
4593 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004594 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004595 ),
4596 "types": [DType.FP32],
4597 "error_if_validators": (
4598 TosaErrorValidator.evWrongInputType,
4599 TosaErrorValidator.evWrongOutputType,
4600 TosaErrorValidator.evWrongInputList,
4601 TosaErrorValidator.evWrongOutputList,
4602 TosaErrorValidator.evWrongRank,
4603 TosaErrorValidator.evBatchMismatch,
4604 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004605 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004606 ),
4607 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004608 }
4609
Kevin Cheng550ccc52021-03-03 11:21:43 -08004610
Eric Kunzee5e26762020-10-13 16:11:07 -07004611class OutputShaper:
4612 # Methods in this class compute the expected output shape and datatype
4613 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; there is no instance state to set up."""
4616
    # These methods compute the output shape and dtype, create the output
    # tensor with ser.addOutput() and return its result
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary op with broadcasting.

    With no error requested, each output dimension comes from ``a`` unless
    that dimension is 1, in which case ``b``'s matching dimension is used.
    When any ``error_name`` is given, the output starts from ``a``'s shape.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` call
            is returned.
        rng: random generator providing ``integers`` and ``choice``
            (presumably a numpy Generator - TODO confirm).
        a, b: input tensors exposing ``shape`` and ``dtype``. Their dtypes
            must match, and so must their ranks unless ``error_name`` is
            ``ErrorIf.RankMismatch`` (both checked with ``assert``).
        error_name: optional ErrorIf case to inject:
            * ``ErrorIf.DimensionMismatch`` - one randomly chosen output
              dimension is increased by 1.
            * ``ErrorIf.WrongOutputType`` - the output dtype is picked at
              random from a fixed dtype list, excluding ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # One value is drawn from rng on every call, whatever error_name is.
    # NOTE(review): numpy raises for an empty range, so this seems to assume
    # a has rank >= 1 - confirm rank-0 inputs never reach here.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004652
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op that does not broadcast.

    Both inputs must have the same rank, the same dtype and identical
    dimension sizes; each is checked with ``assert``. The output uses a new
    list holding ``a``'s dimensions, plus ``a``'s dtype.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` call
            is returned.
        a, b: input tensors exposing ``shape`` and ``dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
    # A new list is passed, not a.shape itself
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004664
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output is given ``a.shape`` (the same object, not a copy). Its dtype
    is ``a.dtype``, unless ``error_name`` is ``ErrorIf.WrongOutputType``. In
    that case a dtype is picked at random from a fixed list, excluding
    ``a.dtype``.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` call
            is returned.
        rng: random generator providing ``choice`` (presumably a numpy
            Generator - TODO confirm).
        a: input tensor exposing ``shape`` and ``dtype``.
        error_name: optional ErrorIf case; only ``WrongOutputType`` has an
            effect here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004683
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select op (``cond``, ``a``, ``b``).

    The output shape follows ``cond``. With no error requested, any
    dimension where ``cond`` has size 1 becomes the largest of the three
    inputs' sizes for that dimension.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` call
            is returned.
        rng: random generator providing ``integers`` and ``choice``
            (presumably a numpy Generator - TODO confirm).
        cond, a, b: tensors exposing ``shape`` (and ``dtype`` for ``a``/``b``).
            ``a`` and ``b`` must share a dtype. All three must share a rank
            unless ``error_name`` is ``ErrorIf.RankMismatch``. Both checks use
            ``assert``.
        error_name: optional ErrorIf case to inject:
            * ``ErrorIf.DimensionMismatch`` - one randomly chosen output
              dimension is increased by 1.
            * ``ErrorIf.WrongOutputType`` - the output dtype is picked at
              random from a fixed dtype list, excluding ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for idx, cond_dim in enumerate(cond.shape):
        out_dim = cond_dim
        if cond_dim == 1 and error_name is None:
            out_dim = max(cond_dim, a.shape[idx], b.shape[idx])
        shape.append(out_dim)

    # One value is drawn from rng on every call, whatever error_name is
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
        return ser.addOutput(shape, rng.choice(wrong_dtypes))
    return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004717
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of an elementwise comparison operator.

    Args:
        ser: serializer; the output is registered with ``ser.addOutput``.
        rng: random generator used to fuzz the shape/dtype for ERROR_IF tests.
        a, b: input tensors (their ``shape`` and ``dtype`` are read).
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns. The output dtype is ``DType.BOOL``
        unless a wrong output type is being generated.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast: a size-1 dimension of a takes b's size when b has that dim.
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # A rank-0 input has no dimension to fuzz, and rng.integers(0, 0) raises.
    # For rank >= 1 the index is still drawn unconditionally, so the random
    # stream is unchanged for every other test.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is the expected output type and is not in this list.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004751
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a reduction operator.

    For valid tests the reduced ``axis`` becomes size 1. The invalid-axis
    ERROR_IF cases keep the input shape. ``ErrorIf.ShapeOfAxisNotOne`` also
    replaces a size-1 axis with a size drawn from ``rng.integers(2, 10)``.
    ``ErrorIf.WrongOutputType`` draws a dtype other than ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    keep_axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keep_axis_errors:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004780
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of ARGMAX.

    The reduced ``axis`` is removed from the input shape, except for the
    invalid-axis ERROR_IF cases. The output rank/shape mismatch ERROR_IF
    cases then perturb the shape with ``rng``. The dtype is ``DType.INT32``,
    or a different dtype for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        # Only shrink when a dimension would remain; otherwise grow the rank.
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004814
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

    Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.
    ``ErrorIf.ConvOutputShapeMismatch`` grows H and/or W by a multiple of the
    stride. The output dtype is ``accum_dtype``, or ``DType.INT32`` for
    ``ErrorIf.WrongInputType``. For ``ErrorIf.WrongOutputType`` it is drawn
    from ``gtu.usableDTypes`` with the valid choices excluded.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output size of one spatial dimension.
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer output size error is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a wrong input type, fall back to a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are valid output types).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004866
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

    Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.
    ``ErrorIf.ConvOutputShapeMismatch`` grows D, H and/or W by a multiple of
    the stride. The output dtype is ``accum_dtype``, or ``DType.INT32`` for
    ``ErrorIf.WrongInputType``. For ``ErrorIf.WrongOutputType`` it is drawn
    from ``gtu.usableDTypes`` with the valid choices excluded.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output size of one spatial dimension.
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    d = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer output size error is not hit.
        if change in (1, 4):
            d += rng.choice(choices) * strides[0]
        if change in (2, 4):
            h += rng.choice(choices) * strides[1]
        if change in (3, 4):
            w += rng.choice(choices) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # With a wrong input type, fall back to a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are valid output types).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4928
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.
    ``ErrorIf.ConvOutputShapeMismatch`` grows H and/or W by a multiple of the
    stride. The output dtype is ``accum_dtype``, or ``DType.INT32`` for
    ``ErrorIf.WrongInputType``. For ``ErrorIf.WrongOutputType`` it is drawn
    from ``gtu.usableDTypes`` with the valid choices excluded.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output size of one spatial dimension.
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer output size error is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    # Output channels = input channels (C) * channel multiplier (M).
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # With a wrong input type, fall back to a potentially correct output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are valid output types).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling operator (input is NHWC).

    A non-positive stride or a negative pad makes the test invalid anyway, so
    the spatial sizes are then set to 1. Otherwise each spatial size is
    ``(in + pad_before + pad_after - kernel) // stride + 1``.
    ``ErrorIf.PoolingOutputShapeMismatch`` grows H and/or W by a multiple of
    the stride. The output keeps ``ifm.dtype``, except for
    ``ErrorIf.WrongOutputType``, which draws a different dtype.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        h, w = 1, 1
    else:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer output size error is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005017
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor of FULLY_CONNECTED.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).

    The output dtype is always ``accum_dtype``. It is validated in arg_gen
    (and invalidated there for ERROR_IF), so ``rng`` and ``error_name`` are
    unused here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005030
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor of MATMUL.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    Args:
        ser: serializer; the output is registered with ``ser.addOutput``.
        rng: random generator used to pick a wrong dtype for ERROR_IF tests.
        a, b: input tensors (their ``shape`` and ``dtype`` are read).
        accum_dtype: accumulator/output dtype, validated in arg_gen.
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype previously left incorrect_types unbound
            # (UnboundLocalError). Fall back to every listed dtype except the
            # expected accumulator type.
            incorrect_types = tuple(
                dt
                for dt in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dt != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005078
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor of CONCAT.

    The output starts from the first input's shape. The sizes of the other
    inputs along ``axis`` are added to it, unless the test targets mismatched
    ranks or an invalid axis, where no sum is possible.
    ``ErrorIf.ConcatShapeSumMismatch`` adds ``rng.integers(5, 10)`` to the
    axis. The dtype is the first input's, or a different one for
    ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, first.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    return ser.addOutput(
        out_shape, rng.choice(list(candidate_dtypes - {first.dtype}))
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07005114
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor of PAD.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ``ErrorIf.PadOutputShapeMismatch`` shrinks one random dimension by 1 or 2.
    ``ErrorIf.RankMismatch`` substitutes a shape of a different rank. The
    dtype is ``a.dtype``, or a different one for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = [
        padding[idx][0] + padding[idx][1] + dim for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005145
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of DIM.

    The output is a shape-[1] tensor of ``DType.SHAPE``; ``a`` and ``axis``
    do not affect it here. For ``ErrorIf.WrongOutputType`` one of the listed
    non-SHAPE dtypes is drawn with ``rng`` instead.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # rng picks by position, so keep the set round-trip ordering.
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5166
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor of RESHAPE with the requested ``shape``.

    ``ErrorIf.TensorSizeInputOutputMismatch`` grows every dimension by
    ``rng.integers(1, 10)``. The dtype is ``a.dtype``, or a different one for
    ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005191
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor of SLICE.

    The output shape is ``size``. ``ErrorIf.SizeOutputShapeMismatch`` nudges
    every dimension (only upwards when it is <= 2).
    ``ErrorIf.InputSizeStartLengthMismatch`` uses the input's shape.
    ``ErrorIf.RankMismatch`` substitutes a shape of a different rank. The
    dtype is ``input.dtype``, or a different one for
    ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dimensions only grow, so they never reach zero or below.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005225
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor of TILE.

    Each output dimension is ``a.shape[i] * multiples[i]``.
    ``ErrorIf.RankMismatch`` substitutes a shape of a different rank. The
    dtype is ``a.dtype``, or a different one for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [a.shape[idx] * multiples[idx] for idx in range(len(a.shape))]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005254
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor of TRANSPOSE.

    Output dimension ``i`` is ``a.shape[perms[i]]``. The invalid-permutation
    ERROR_IF cases keep the input shape.
    ``ErrorIf.TensorSizeInputOutputMismatch`` grows every dimension.
    ``ErrorIf.RankMismatch`` substitutes a shape of a different rank. The
    dtype is ``a.dtype``, or a different one for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        out_shape = [a.shape[perms[idx]] for idx in range(len(out_shape))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(out_shape)):
            out_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005287
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor of GATHER.

    The shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    Rank/batch consistency is asserted unless the test targets
    ``ErrorIf.WrongRank``. The dtype is ``values.dtype``, or a different one
    for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, width, channels]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005313
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor of SCATTER.

    The output has the same shape as ``values_in``. Rank and N/W/C
    consistency between ``values_in``, ``indices`` and ``input`` is asserted
    unless the test targets ``ErrorIf.WrongRank``. The dtype is
    ``values_in.dtype``, or a different one for ``ErrorIf.WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy so the output tensor does not share the input tensor's shape list,
    # matching the other output shapers, which all build a new list.
    output_shape = list(values_in.shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005342
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor of TABLE.

    The output has the same shape as ``input``. Its dtype is ``DType.INT32``
    for an INT16 input and ``DType.INT8`` otherwise. The input must be INT8 or
    INT16 unless the test targets ``ErrorIf.WrongInputType``.
    ``ErrorIf.WrongOutputType`` draws a different dtype from a fixed list.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            dt
            for dt in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dt != out_dtype
        ]
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005362
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for a RESIZE operation.

    The output height (oh) is computed from input.shape[1] and the output
    width (ow) from input.shape[2], using scale, offset and border. The
    first and last output dims are copied from input.shape[0] and
    input.shape[3], except where an ErrorIf case alters them.

    Args:
        serializer: serializer that receives the new output tensor.
        rng: random generator used by the ErrorIf cases.
        input: input tensor. Presumably rank 4 (N, H, W, C) outside the
            ErrorIf.WrongRank case -- TODO confirm against the arg generator.
        mode: not used by this method.
        scale: sequence read as [scale_y_n, scale_y_d, scale_x_n, scale_x_d].
        offset: offset[0] is applied to the height, offset[1] to the width.
        border: border[0] is applied to the height, border[1] to the width.
        input_dtype: not used by this method.
        output_dtype: dtype given to the output tensor.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``serializer.addOutput`` returns for the new tensor.
    """
    # Calculate OH, OW
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # The test deliberately uses non-positive scales; clamp them to at
        # least 1 here so the floor divisions below cannot divide by zero.
        scale_y_n = max(scale_y_n, 1)
        scale_y_d = max(scale_y_d, 1)
        scale_x_n = max(scale_x_n, 1)
        scale_x_d = max(scale_x_d, 1)

    # out = ((in - 1) * scale_n - offset + border) // scale_d + 1
    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh = max(oh, 1)
        ow = max(ow, 1)
        # Keep both dims below the resize limit, except when the test is
        # specifically meant to exceed it.
        if error_name != ErrorIf.MaxDimExceeded:
            oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
            ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> change height only, 2 -> change width only, 3 -> change both
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if change in [1, 3]:
            # Step down instead of up if stepping up would reach the limit.
            if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += scale_y_d
        if change in [2, 3]:
            if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += scale_x_d

    if error_name == ErrorIf.WrongRank:
        # The last dim reuses input.shape[0] rather than input.shape[3].
        # NOTE(review): presumably because the input may not have a 4th dim
        # in this case -- confirm.
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[0],
        ]
    elif error_name == ErrorIf.BatchMismatch:
        # Output batch is larger than the input batch by a random 1..9.
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            oh,
            ow,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        # Output channels are larger than the input channels by a random 1..9.
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005441
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type-conversion operation.

    The output has the same shape as ``val`` and the requested dtype.

    Args:
        ser: serializer that receives the new output tensor.
        rng: not used by this method.
        val: input tensor; only its shape is read.
        out_dtype: dtype of the output tensor.
        error_name: not used by this method.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005445
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a TRANSPOSE_CONV2D operation.

    Args:
        ser: serializer that receives the new output tensor.
        rng: random generator used by the ErrorIf cases.
        ifm: input feature-map tensor; only its dtype is read here.
        output_shape: requested output shape. NOTE(review): for
            ErrorIf.ConvOutputShapeMismatch this object is modified in
            place, so the caller sees the change too. Confirm whether
            callers rely on that before switching to a copy.
        accum_dtype: accumulator dtype; used as the output dtype unless
            an ErrorIf case overrides it.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        # pick 1 -> grow dim 1 only, 2 -> grow dim 2 only, 3 -> grow both.
        pick = rng.choice(deltas)
        for axis, triggers in ((1, (1, 3)), (2, (2, 3))):
            if pick in triggers:
                output_shape[axis] += rng.choice(deltas)

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        result_dtype = DType.INT32
    else:
        result_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the wrong choices;
        # otherwise only the expected output dtype is excluded.
        banned = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [result_dtype]
        )
        result_dtype = rng.choice(list(gtu.usableDTypes(excludes=banned)))

    return ser.addOutput(output_shape, result_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005471
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an FFT2D operation.

    Both outputs get the same shape and dtype, copied from ``ifm1`` unless
    an ErrorIf case modifies them. NOTE(review): the two outputs are
    presumably the real and imaginary parts -- confirm against the op
    definition.

    Args:
        serializer: serializer that receives the new output tensors.
        rng: random generator used by the ErrorIf cases.
        ifm1: first input tensor; must have the same dtype as ``ifm2``.
        ifm2: second input tensor; must have the same shape as ``ifm1``
            unless ``error_name`` is ErrorIf.FFTInputShapeMismatch.
        error_name: optional ErrorIf case being generated.

    Returns:
        A two-element list holding what ``serializer.addOutput`` returned
        for each output.
    """
    assert ifm1.dtype == ifm2.dtype
    dtype_out = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    shape_in = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(shape_in) == 3

    shape_out = shape_in.copy()

    if error_name == ErrorIf.WrongOutputType:
        dtype_out = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        shape_out[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow either dim 1 or dim 2 by a random 1..9.
        dim = rng.choice([1, 2])
        shape_out[dim] += rng.integers(1, 10)

    # Both outputs are registered with the same shape list and dtype.
    return [serializer.addOutput(shape_out, dtype_out) for _ in range(2)]
5502
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an RFFT2D operation.

    The output shape equals the input shape except for the last dimension,
    which becomes ``W // 2 + 1`` where W is the input's last dimension.
    Both outputs share that shape and the input dtype unless an ErrorIf case
    modifies them. NOTE(review): the two outputs are presumably the real
    and imaginary parts -- confirm against the op definition.

    Args:
        serializer: serializer that receives the new output tensors.
        rng: random generator used by the ErrorIf cases.
        value: input tensor; must be rank 3 unless ``error_name`` is
            ErrorIf.WrongRank.
        error_name: optional ErrorIf case being generated.

    Returns:
        A two-element list holding what ``serializer.addOutput`` returned
        for each output.
    """
    shape_in = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(shape_in) == 3

    last = shape_in[-1]
    shape_out = list(shape_in[:-1]) + [last // 2 + 1]

    dtype_out = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_out = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        shape_out[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow either dim 1 or dim 2 by a random 1..9.
        dim = rng.choice([1, 2])
        shape_out[dim] += rng.integers(1, 10)

    # Both outputs are registered with the same shape list and dtype.
    return [serializer.addOutput(shape_out, dtype_out) for _ in range(2)]