blob: d4dbef4e0b8675e73b0c4e92ce4c82c6bcce92b7 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data file; the
# copyright year is taken from the date this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further 8K level maximums; by name these cap scale, kernel size and
    # stride values - confirm against the TOSA specification's level limits.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # (number of test sets and a minimum value; their exact use is outside
    # this part of the file - TODO confirm against the data generator)
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
79 for dtype in (DType.FP32, DType.FP16, DType.BF16):
80 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100155 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
170 elif dtype in (DType.INT32, DType.SHAPE):
171 # restricting too large value for SHAPE
172 rng = (-(1 << 31), (1 << 31))
173 elif dtype == DType.INT48:
174 rng = (-(1 << 47), (1 << 47))
175 else:
176 raise Exception("Unknown dtype: {}".format(dtype))
177
178 if not high_inclusive:
179 # Exclusive high: low <= range < high
180 return rng
181 else:
182 # Inclusive range: low <= range <= high
183 return (rng[0], rng[1] - 1)
184
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000185 def getRandTensor(self, shape, dtype, data_range=None):
186 if data_range is None:
187 low, high = self.getDTypeRange(dtype)
188 else:
189 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100190
Eric Kunzee5e26762020-10-13 16:11:07 -0700191 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100194 return np.int64(self.rng.integers(low=low, high=high, size=shape))
195 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
196 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
197
198 if dtype == DType.FP16:
199 return np.float16(f_tensor)
200 else:
201 f32_tensor = np.float32(f_tensor)
202 if dtype == DType.BF16:
203 # Floor the last 16 bits of each f32 value
204 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
205 else:
206 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100208 # All other integer types
209 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700212 placeholders = []
213
Kevin Cheng989cb052021-04-28 16:29:44 -0700214 assert len(shape_list) == len(dtype_list)
215
Jeremy Johnson1271c442023-09-05 11:39:26 +0100216 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700217 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100218 if not self.args.lazy_data_gen:
219 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700220 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700221
222 return placeholders
223
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700225 consts = []
226
Kevin Cheng989cb052021-04-28 16:29:44 -0700227 assert len(shape_list) == len(dtype_list)
228
Jeremy Johnson1271c442023-09-05 11:39:26 +0100229 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700230 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100231 if not self.args.lazy_data_gen:
232 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700233 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700234
235 return consts
236
237 def makeShape(self, rank):
238 if self.targetted_shape:
239 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800240 return np.int32(
241 self.rng.integers(
242 low=self.args.tensor_shape_range[0],
243 high=self.args.tensor_shape_range[1],
244 size=rank,
245 )
246 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700247
    def setTargetShape(self, shape):
        """Make makeShape return ``shape``; set None to go back to random shapes."""
        self.targetted_shape = shape
250
251 def randInt(self, low=0, high=256):
252 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
253
254 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100255 low, high = self.getDTypeRange(dtype)
256
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100257 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100258 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100259 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100260 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100261 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100262 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
263 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 elif dtype == DType.BOOL:
265 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700266 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700267 # Special size
268 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700269
270 return np.int32(self.rng.integers(low, high, size=1))[0]
271
272 def shapeStr(self, shape):
273
274 sStr = []
275 # Convert to strings
276 for i in shape:
277 sStr.append(str(i))
278
Kevin Cheng550ccc52021-03-03 11:21:43 -0800279 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700280
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100281 def typeStr(self, dtype):
282 if isinstance(dtype, list) or isinstance(dtype, tuple):
283 assert len(dtype) >= 2
284 strs = [self.typeStr(t) for t in dtype]
285 # Limit types to the first 2 as the 3rd is the accumulator
286 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100288 if dtype in gtu.DTYPE_ATTRIBUTES:
289 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700290 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100291 raise Exception(
292 "Unknown dtype, cannot convert to string: {}".format(dtype)
293 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100295 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100296 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100297 if dtype in gtu.DTYPE_ATTRIBUTES:
298 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700299 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100300 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700301
Luke Hutton57287132023-02-06 14:54:18 +0000302 def constrictBatchSize(self, shape):
303 # Limit the batch size unless an explicit target shape set
304 if self.args.max_batch_size and not self.args.target_shapes:
305 shape[0] = min(shape[0], self.args.max_batch_size)
306 return shape
307
James Ward30124a82023-02-02 14:56:33 +0000308 def makeDimension(self):
309 return self.randInt(
310 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
311 )
312
    def tensorComplianceMetaData(
        self, op, inputType, argsDict, outputTensor, errorName
    ):
        """Build the compliance meta data dict for an expected output tensor.

        Returns None when ``errorName`` is set, when the output dtype is not
        supported by compliance (gtu.dtypeIsSupportedByCompliance), or when
        the op is MATMUL/CONV2D/FULLY_CONNECTED and ``inputType`` is not
        supported by compliance.

        Otherwise returns a dict with "mode" (the selected
        gtu.ComplianceMode name) and "data_type" (the output dtype's "json"
        attribute). The mode is chosen by the first matching rule:
          - DOT_PRODUCT if argsDict["dg_type"] is DataGenType.DOT_PRODUCT;
            adds "dot_product_info" with "s" and "ks" (int of argsDict["ksb"]
            if present, else int of argsDict["ks"])
          - FP_SPECIAL if argsDict["dg_type"] is DataGenType.OP_SPECIAL
          - ULP if op["compliance"] has "ulp"; adds "ulp_info"
          - REDUCE_PRODUCT for Op.REDUCE_PRODUCT; adds "reduce_product_info"
            with argsDict["n"]
          - ABS_ERROR for EXP/POW/TANH/SIGMOID; adds "abs_error_info" only if
            op["compliance"] has "abs_error_lower_bound"
          - EXACT otherwise

        Args:
            op: operator definition dict (reads "op" and optional "compliance")
            inputType: dtype of the op input
            argsDict: test argument dict; "dg_type" is always read
            outputTensor: expected result tensor; only its dtype is read
            errorName: error-if test name, or None for a normal test
        """
        # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
        UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
        if (
            errorName
            or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
            or (
                not gtu.dtypeIsSupportedByCompliance(inputType)
                and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
            )
        ):
            # No compliance for error tests or unsupported types currently
            return None

        # Create compliance meta data for expected output tensor
        compliance_tens = {
            "mode": None,
            # Data type is needed for all FP runs, as refmodel precise mode produces FP64
            "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
        }
        # Order matters: the first matching branch decides the mode
        if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
            mode = gtu.ComplianceMode.DOT_PRODUCT
            compliance_tens["dot_product_info"] = {
                "s": argsDict["s"],
                "ks": int(argsDict["ksb"])
                if "ksb" in argsDict
                else int(argsDict["ks"]),
            }
        elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
            mode = gtu.ComplianceMode.FP_SPECIAL
        elif "compliance" in op and "ulp" in op["compliance"]:
            mode = gtu.ComplianceMode.ULP
            compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        elif op["op"] == Op.REDUCE_PRODUCT:
            mode = gtu.ComplianceMode.REDUCE_PRODUCT
            compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
        elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
            mode = gtu.ComplianceMode.ABS_ERROR
            if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
                compliance_tens["abs_error_info"] = {
                    "lower_bound": op["compliance"]["abs_error_lower_bound"]
                }
        else:
            mode = gtu.ComplianceMode.EXACT
        # Store the enum member name (a string) in the JSON meta data
        compliance_tens["mode"] = gtu.ComplianceMode(mode).name

        return compliance_tens
362
    # Build Op functions
    # Create the output tensor (calling OutputShaper as needed)
    # Do final tweaks to attributes (if necessary for errorIf)
    # Add Op into graph
    # Return resulting tensor information or BuildInfo

    class BuildInfo:
        """Enhanced build information containing result tensor and associated compliance dict."""

        def __init__(self, resultTensor, complianceDict):
            # Output tensor created by the build function
            self.resultTensor = resultTensor
            # Compliance meta data for the result; the builders in this file
            # pass the value of tensorComplianceMetaData, which may be None
            self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700375
    def build_unary(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a single-input operator to the graph.

        Args:
            op: operator definition dict (reads "op" and "operands")
            inputs: list holding exactly one input tensor
            args_dict: test arguments, forwarded to tensorComplianceMetaData
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test
            qinfo: zero points; qinfo[0] and qinfo[1] feed NegateAttribute

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance data,
            or None when TosaErrorValidator.evValidateErrorIfs returns a falsy
            value (NOTE(review): presumably "skip this error-if variant").
        """
        assert len(inputs) == 1
        a = inputs[0]
        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # op is indexed as a dict below, so a bare integer opcode is rejected
        assert not isinstance(op, int)

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = [
                    TosaQuantGen.getZeroPoint(self, a.dtype),
                    TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
                ]

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = None
        if op["op"] == Op.NEGATE:
            # NOTE(review): indexing qinfo raises TypeError if it is None;
            # presumably callers always supply zero points for NEGATE
            attr = ts.TosaSerializerAttribute()
            attr.NegateAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700428
    def build_binary_broadcast(
        self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
    ):
        """Add a two-input operator whose output comes from binaryBroadcastOp.

        Args:
            op: operator definition dict (reads "op" and "operands")
            inputs: list holding exactly two input tensors (a, b)
            args_dict: test arguments, forwarded to tensorComplianceMetaData
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test
            qinfo: accepted but not used by this builder

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance data,
            or None when TosaErrorValidator.evValidateErrorIfs returns a falsy
            value.
        """
        assert len(inputs) == 2
        a, b = inputs
        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        # Compliance is based on the dtype of the first input
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700470
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700472 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700474 return result_tens
475
    def build_arithmetic_right_shift(
        self, op, a, b, round, validator_fcns=None, error_name=None
    ):
        """Add an ARITHMETIC_RIGHT_SHIFT style operator with a round attribute.

        Args:
            op: operator definition dict (reads "op" and "operands")
            a, b: input tensors
            round: value passed to ArithmeticRightShiftAttribute (the name
                shadows the builtin ``round`` inside this method)
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test

        Returns:
            The result tensor, or None when
            TosaErrorValidator.evValidateErrorIfs returns a falsy value.
        """
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
513
    def build_mul(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a MUL style operator with a shift attribute.

        For non floating point inputs the result dtype is forced to INT32.
        For the WrongOutputType error test it is replaced by a random choice
        of INT8/INT16/INT48 (consuming one draw from ``self.rng``).

        Args:
            op: operator definition dict (reads "op" and "operands")
            inputs: list holding exactly two input tensors (a, b)
            args_dict: test arguments; "shift" is read, and the dict is also
                forwarded to tensorComplianceMetaData
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test
            qinfo: accepted but not used by this builder

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance data,
            or None when TosaErrorValidator.evValidateErrorIfs returns a falsy
            value.
        """
        assert len(inputs) == 2
        a, b = inputs
        shift = args_dict["shift"]

        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Special for multiply: Force the result to INT32 for INT types
        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
            result_tensor.setDtype(DType.INT32)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
            outputDType = self.rng.choice(all_dtypes)
            result_tensor.setDtype(outputDType)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
    def build_table(self, op, a, table, validator_fcns=None, error_name=None):
        """Add a TABLE style operator using ``table`` as its lookup attribute.

        Args:
            op: operator definition dict (reads "op" and "operands")
            a: input tensor
            table: value passed to TableAttribute
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test

        Returns:
            The result tensor, or None when
            TosaErrorValidator.evValidateErrorIfs returns a falsy value.
        """
        result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(table)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        return result_tens
603
    def build_select(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a SELECT style operator with inputs (cond, a, b).

        Args:
            op: operator definition dict (reads "op" and "operands")
            inputs: list holding exactly three tensors (cond, a, b)
            args_dict: test arguments, forwarded to tensorComplianceMetaData
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test
            qinfo: accepted but not used by this builder

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance data
            (computed from a.dtype), or None when
            TosaErrorValidator.evValidateErrorIfs returns a falsy value.
        """
        assert len(inputs) == 3
        cond, a, b = inputs

        result_tensor = OutputShaper.selectOp(
            self.ser, self.rng, cond, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [cond.name, a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=cond,
            input2=a,
            input3=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700651
    def build_comparison(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a two-input comparison operator shaped by binaryComparisonOp.

        Args:
            op: operator definition dict (reads "op" and "operands")
            inputs: list holding exactly two input tensors (a, b)
            args_dict: test arguments, forwarded to tensorComplianceMetaData
            validator_fcns: error-if validators passed to evValidateErrorIfs
            error_name: ErrorIf test name, or None for a normal test
            qinfo: accepted but not used by this builder

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance data
            (computed from a.dtype), or None when
            TosaErrorValidator.evValidateErrorIfs returns a falsy value.
        """
        assert len(inputs) == 2
        a, b = inputs

        result_tensor = OutputShaper.binaryComparisonOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700699
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX-style operator reducing over ``args_dict["axis"]``.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``axis`` is read here and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the input/output name lists if needed.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator into the serializer.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly one input tensor.
        args_dict: Test arguments; reads ``stride``, ``pad``, ``kernel`` and,
            when present, ``acc_type`` (``DType.UNKNOWN`` otherwise).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``PoolAttribute``; ``[0, 0]`` when None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # For WrongInputType tests with a non-8-bit input, recompute the zero
    # points from the actual input and result dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the input/output name lists if needed.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    zero_points = qinfo if qinfo is not None else [0, 0]

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100817
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D convolution operator into the serializer.

    Args:
        op: Operator description dict; ``op["op"]`` is emitted and
            ``sum(op["operands"])`` is passed to the validators.
        inputs: Exactly three tensors ``(ifm, filter, bias)``.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (4 values) and ``dilation``, and is forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``ConvAttribute``; treated as
            ``[0, 0]`` when None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For WrongInputType tests with a non-8-bit input, recompute the zero
    # points from the actual input and result dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700899
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 3D convolution operator into the serializer.

    Args:
        op: Operator description dict; ``op["op"]`` is emitted and
            ``sum(op["operands"])`` is passed to the validators.
        inputs: Exactly three tensors ``(ifm, filter, bias)``.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (6 values) and ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``ConvAttribute``; treated as
            ``[0, 0]`` when None.

    Returns:
        The result tensor (no compliance metadata is attached here), or None
        if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For WrongInputType tests with a non-8-bit input, recompute the zero
    # points from the actual input and result dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
976
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D transpose convolution operator into the serializer.

    Args:
        op: Operator description dict; ``op["op"]`` is emitted and
            ``sum(op["operands"])`` is passed to the validators.
        ifm, filter, bias: Input, weight and bias tensors. ``filter`` keeps
            its name for caller compatibility even though it shadows a builtin.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride values passed to the attribute and validators.
        out_pad: Output padding; must have 4 values.
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``TransposeConvAttribute``; treated
            as ``[0, 0]`` when None.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # For WrongInputType tests with a non-8-bit input, recompute the zero
    # points from the actual input and result dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1044
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a depthwise 2D convolution operator into the serializer.

    Args:
        op: Operator description dict; ``op["op"]`` is emitted and
            ``sum(op["operands"])`` is passed to the validators.
        inputs: Exactly three tensors ``(ifm, filter, bias)``.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``ConvAttribute``; treated as
            ``[0, 0]`` when None.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For WrongInputType tests with a non-8-bit input, recompute the zero
    # points from the actual input and result dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1120
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a fully-connected operator into the serializer.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly three tensors ``(ifm, filter, bias)``.
        args_dict: Test arguments; ``acc_type`` is read here and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``FullyConnectedAttribute``;
            treated as ``[0, 0]`` when None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001176
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL-style operator into the serializer.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly two input tensors ``(a, b)``.
        args_dict: Test arguments; ``acc_type`` is read here and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Two zero points passed to ``MatMulAttribute``; treated as
            ``[0, 0]`` when None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature defaults qinfo to None; fall back to zero zero-points
    # rather than failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001226
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator (``op["op"]``) along ``args_dict["axis"]``.

    Side effect: for a valid test (``error_name`` is None) of
    ``Op.REDUCE_PRODUCT``, writes the size of the reduced dimension into
    ``args_dict["n"]`` before building the compliance metadata.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``axis`` is read here.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the input/output name lists if needed.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    needs_product_count = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if needs_product_count:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001275
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP-style operator with randomly chosen bounds.

    Two random values of the input dtype are drawn. They become
    ``(min_val, max_val)`` in ascending order. For the ``MaxSmallerMin``
    ERROR_IF they are redrawn until distinct and then deliberately swapped.

    Args:
        op: Operator description dict (``op["op"]`` and ``op["operands"]``).
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF being exercised, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None if ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    bounds = [self.getRandNumberDType(src.dtype), self.getRandNumberDType(src.dtype)]

    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [
                self.getRandNumberDType(src.dtype),
                self.getRandNumberDType(src.dtype),
            ]
        min_val, max_val = max(bounds), min(bounds)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the input/output name lists if needed.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float types: bounds go in the first pair, zeros in the second.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float types: zeros in the first pair, bounds in the second.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001341
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add ``op["op"]`` as a unary operator carrying a ``LeakyReluAttribute``.

    The attribute value is a random FP32 number. No ERROR_IF validation is
    run here: ``validator_fcns`` is unused, and ``error_name`` is only
    forwarded to the output shaper.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    random_value = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(random_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1350
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add ``op["op"]`` as a single-input operator with no attribute.

    TODO: only one input is wired up; an additional type/input is still
    required. ``validator_fcns`` is unused, and ``error_name`` is only
    forwarded to the output shaper.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1357
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator (no attribute).

    Args:
        op: operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the (placeholder, const) operand counts.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001398
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT operator over ``inputs`` with ``args_dict["axis"]``.

    Args:
        op: operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the (placeholder, const) operand counts.
        inputs: sequence of input tensors to concatenate.
        args_dict: must contain ``"axis"``; also forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Except in the WrongInputType ERROR_IF case, the axis must be an int.
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001451
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator for the single tensor in ``inputs``.

    ``args_dict`` supplies ``"pad"`` (an array that is flattened into the
    attribute), ``"pad_const_int"`` and ``"pad_const_fp"``. ``qinfo`` is
    passed to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_spec = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, pad_spec, error_name)

    # The attribute writes into the serializer builder, so it is created
    # before the ERROR_IF checks run.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_spec.flatten(), const_int, const_fp)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=pad_spec,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)
    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001509
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator with an axis attribute from ``args_dict["axis"]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` used).
        inputs: sequence holding exactly one input tensor.
        args_dict: must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` whose compliance entry is None, or None
        when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    dim_attr = ts.TosaSerializerAttribute()
    dim_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, dim_attr)
    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001555
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator targeting ``args_dict["new_shape"]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` used).
        inputs: sequence holding exactly one input tensor.
        args_dict: must contain ``"new_shape"``; also forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(self.ser, self.rng, src, target_shape, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(target_shape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001601
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator with an axis attribute from ``args_dict["axis"]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` used).
        inputs: sequence holding exactly one input tensor.
        args_dict: must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` whose compliance entry is None, or None
        when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    ):
        return None

    reverse_attr = ts.TosaSerializerAttribute()
    reverse_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, reverse_attr)
    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001641
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on tensor ``a`` with a ``perms`` attribute.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return out
1677
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator on tensor ``a`` using ``start`` and ``size``.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        start=start,
        size=size,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not checks_ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out
1716
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator on tensor ``a`` with a ``multiples`` attribute.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out
1751
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from the ``(values, indices)`` input pair.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` used).
        inputs: sequence of exactly two tensors: values, then indices.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata (keyed on the values dtype), or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out.shape,
        input_dtype=values.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001794
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from ``(values_in, indices, input)`` tensors.

    Args:
        op: operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the (placeholder, const) operand counts.
        inputs: sequence of exactly three tensors, in the order
            values_in, indices, input.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata (keyed on the values_in dtype), or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    # Named ``input_tensor`` (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001836
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator for tensor ``input``.

    ``scale``, ``offset``, ``border`` and ``mode`` form the resize attribute.
    ``input_dtype`` and ``output_dtype`` are passed to the output shaper and
    to the ERROR_IF validators. The parameter name ``input`` shadows the
    builtin but is kept so existing callers keep working.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out],
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out
1898
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with inputs ``val`` and ``val2``.

    One output tensor per input is created via ``OutputShaper.unaryOp``.
    No ERROR_IF validation is run; ``validator_fcns`` is unused.

    Args:
        op: operator description dict whose ``"op"`` entry is serialized.
            A bare operator value is also accepted and passed through.
        val: first input tensor.
        val2: second input tensor.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The first output tensor (the second is registered but not returned).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder serializes op["op"]; passing the whole description
    # dict to addOperator would not give it an operator id.
    op_id = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_id, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1906
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor on the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for signature
    compatibility with the other builders but are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001910
# Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST of the single input tensor to ``args_dict["out_type"]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` used).
        inputs: sequence holding exactly one input tensor.
        args_dict: must contain ``"out_type"``; also forwarded to the
            compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata (keyed on the input dtype), or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001955
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator test.

    Picks input/output zero points, derives multiplier/shift pairs from a
    random scale (one pair per last-dimension channel when ``per_channel``),
    and, for positive scale32 tests, clamps the saved input placeholder data
    so that it satisfies the apply_scale_32 REQUIRES condition. Finally runs
    the ERROR_IF validators and serializes the operator.

    Args:
        op: Operator description dict (the "op" and "operands" keys are read).
        val: Input tensor; ``dtype``, ``shape`` and ``name`` are read, and
            ``placeholderFilename`` for positive scale32 tests.
        out_dtype: Output data type.
        scale32: True for 32-bit scaling, False for 16-bit scaling.
        double_round: Passed through to the RESCALE attribute.
        per_channel: Use one multiplier/shift per channel of the last dim.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension), or a single one
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Choose the input zero point. Every branch except the final `else`
    # widens the type by one bit (presumably to allow for the zero-point
    # offset - TODO confirm).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero (invalid) zero point for these ERROR_IF tests
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Choose the output zero point (mirrors the input logic above)
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero (invalid) zero point for these ERROR_IF tests
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^output_width) / (2^input_width)
    # where a is drawn uniformly from [0, 1).
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    # Per-channel inclusive bounds on (value - input_zp) for apply_scale_32
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints. shift_arr[i] is an np.int32
        # scalar; under NumPy 2 (NEP 50) promotion `-1 << shift_arr[i]` is
        # evaluated in int32 and overflows once shift - 1 reaches 32, which
        # the small scales allowed by the clip above (down to 2^-31) produce.
        half_range = 1 << (int(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1 << (shift - 1)) && value < (1 << (shift - 1)))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2110
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    The constant holds the single value ``cond``. A valid condition is a
    rank-0 BOOL tensor. The CondIfCondNotMatchingBool ERROR_IF test gets a
    wrong dtype instead. The CondIfCondShapeNotSizeOne test gets one of two
    shapes that are not size one, picked at random.
    """
    wants_wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    wants_wrong_shape = error_name == ErrorIf.CondIfCondShapeNotSizeOne

    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if wants_wrong_type
        else DType.BOOL
    )

    # Must be of size 1 (rank 0) unless deliberately testing the shape check
    shape = []
    if wants_wrong_shape:
        shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(shape, cond_type, [cond])
2127
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output an INT32 constant.

    Only ``then_tens.shape`` is used from the supplied tensors. ``else_tens``
    is ignored and its name is reused below for the ELSE block's constant.
    Each block's constant holds random values in [0, 256). For the
    CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the matching
    block outputs a constant with a perturbed shape instead.

    Statement order matters here. The RNG draws happen in a fixed sequence,
    and tensors are presumably added to whichever basic block was most
    recently created with addBasicBlock (TODO confirm in the serializer).

    Returns:
        The INT32 result tensor of shape ``then_tens.shape``, or None when
        ERROR_IF validation fails.
    """
    # For cond_if with constants, we're supplied with then/else tensors that we ignore
    # (except for the generated shape) and the condition.  Build Then/Else blocks
    # and fill them with const nodes for the body.

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Make then/else tensors
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        for i in range(len(incorrect_shape)):
            # Dimensions above 3 may shrink or grow; small ones only grow
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    self.ser.addBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_tens)

    # Validate against the serialized blocks and the condition tensor
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2196
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    The THEN and ELSE blocks use different operators:

    - FP32/BF16/FP16/INT32 inputs: ADD in THEN, SUB in ELSE.
    - INT8/INT16 inputs: LOGICAL_RIGHT_SHIFT in THEN, LOGICAL_LEFT_SHIFT in ELSE.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the matching block is given an input or output with a perturbed shape.

    Args:
        op: Operator description dict (the "op" key is serialized).
        a, b: Input tensors passed to the COND_IF and to both blocks.
        cond: Value stored in the condition constant.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor (shape and dtype of ``a``), or None when ERROR_IF
        validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # NOTE: the loop variable must not be called `op` - that would overwrite
    # the operator dict which is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2279
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that sums ``a`` into an accumulator.

    The loop carries three values: a rank-0 INT32 counter initialised to
    ``iter_val``, ``a``, and an accumulator placeholder initialised to zeros.

    - COND block: computes ``counter > 0``.
    - BODY block: computes ``acc + a`` and ``counter - 1`` and passes ``a``
      through unchanged. The loop therefore presumably runs the body
      ``iter_val`` times.

    For the list-mismatch and cond-output ERROR_IF tests, the relevant
    tensors are given perturbed shapes or dtypes.

    Statement order is significant. RNG draws happen in a fixed sequence, and
    block tensors are presumably added to the most recent addBasicBlock
    (TODO confirm in the serializer).

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation fails.
    """
    # Rank-0 loop counter (named iter_tens to avoid shadowing builtin iter())
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Accumulator output deliberately differs in shape from its input
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # Rank-0 counter: give it a dimension so its shape differs
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2400
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator test from two input tensors.

    ``inverse`` selects the inverse transform and is stored in the FFT
    attribute together with ``local_bound`` (currently always False).

    Returns:
        The list of result tensors produced by ``OutputShaper.fft2dOp``, or
        None when ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    # Gather per-output metadata for the ERROR_IF checks
    output_names = []
    output_shapes = []
    output_dtypes = []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], output_names
    )

    error_if_kwargs = dict(
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2451
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    The RFFT attribute's ``local_bound`` is currently always False.

    Returns:
        The list of result tensors produced by ``OutputShaper.rfft2dOp``, or
        None when ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    # Gather per-output metadata for the ERROR_IF checks
    output_names = []
    output_shapes = []
    output_dtypes = []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], output_names
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)
    return results
2497
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters used to generate tests for op.

    Args:
        op: Operator description dict; ``op["rank"]`` is an inclusive
            ``(min, max)`` pair and ``op["types"]`` the supported dtypes.
        shapeFilter: Requested shapes; a falsy value means no shape filter.
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator (used for negative tests only).

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys.
        Returns None for a negative test without a validator, or for any
        other ``testType``.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes
    # reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the
    # operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by
        # operator, whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        # Shapes were requested: allow every rank the operator supports
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Operator dtypes that were requested; a list entry matches when its
        # first element was requested
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    # The validator may require specific ranks/dtypes/shapes to trigger its
    # ERROR_IF; anything it leaves as None falls back to the cleaned filters.
    error_arguments = validator(check=False, op=op)["param_reqs"]

    if error_arguments["shape"] is not None:
        negShapeFilter = error_arguments["shape"]
    else:
        # Reduce number of shapes to keep test numbers small
        negShapeFilter = shapeFilter[:2]

    if error_arguments["rank"] is not None:
        negRankFilter = error_arguments["rank"]
    else:
        negRankFilter = cleanRankFilter

    if error_arguments["dtype"] is not None:
        negDtypeFilter = error_arguments["dtype"]
    else:
        negDtypeFilter = cleanDtypeFilter

    return {
        "shapeFilter": negShapeFilter,
        "rankFilter": negRankFilter,
        "dtypeFilter": negDtypeFilter,
    }
2578
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for operator ``opName``.

    Each entry is a tuple ``(opName, testStr, dtype, error_name, shapeList,
    args)``. Negative tests are generated once per ERROR_IF validator listed
    in the operator's "error_if_validators".

    Args:
        opName: Key into ``self.TOSA_OP_LIST``.
        shapeFilter: Optional list of shapes. None (the default) means no
            shape restriction, equivalent to the previous default ``[None]``.
        rankFilter: Optional list of ranks.
        dtypeFilter: Optional list of dtypes.
        testType: "positive" or "negative".

    Returns:
        The list of test tuples (possibly empty).

    Raises:
        Exception: If ``opName`` is not a known operator.
    """
    # None instead of a mutable [None] default; normalised for callees
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    if testType == "positive":
                        baseStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                    else:
                        baseStr = "{}_ERRORIF_{}_{}_{}".format(
                            opName, error_name, shapeStr, typeStr
                        )

                    for argStr, args in argList:
                        testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2685
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002686 def serializeTest(
2687 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2688 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002689 try:
2690 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002691 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002692 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002693
Jeremy Johnson0c716862023-04-13 17:18:19 +01002694 if self.args.verbose:
2695 print(f"Creating {testStr}")
2696
Eric Kunzee5e26762020-10-13 16:11:07 -07002697 # Create a serializer
2698 self.createSerializer(opName, testStr)
2699
Jeremy Johnson1271c442023-09-05 11:39:26 +01002700 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002701 if "error_if_validators" in op:
2702 error_if_validators = op["error_if_validators"]
2703 else:
2704 error_if_validators = None
2705
Kevin Cheng550ccc52021-03-03 11:21:43 -08002706 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002707 num_operands = pCount + cCount
2708
2709 if isinstance(dtype_or_dtypeList, list):
2710 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002711 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002712 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002713 else:
2714 dtypeList = [dtype_or_dtypeList] * (num_operands)
2715
Kevin Cheng93a16282021-08-31 16:14:03 -07002716 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002717 assert (
2718 len(shapeList) == num_operands
2719 ), "shapeList length {} must match number of operands {}".format(
2720 len(shapeList), num_operands
2721 )
2722 assert (
2723 len(dtypeList) == num_operands
2724 ), "dtypeList length {} must match number of operands {}".format(
2725 len(dtypeList), num_operands
2726 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002727
2728 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002729 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002730 except KeyError:
2731 qgen = None
2732
2733 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002734
Matthew Haddon1c00b712021-10-01 15:51:03 +01002735 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002736 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002737 else:
2738 qinfo = None
2739
Jeremy Johnson1271c442023-09-05 11:39:26 +01002740 # Extra meta data for the desc.json
2741 tensMeta = {}
2742
2743 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002744 if isinstance(testArgs, dict):
2745 # New interface with args info in dictionary
2746 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002747 assert "dg_type" in argsDict
2748 tvgInfo = tvgen_fcn(
2749 self, opName, dtypeList, shapeList, argsDict, error_name
2750 )
2751 if tvgInfo.dataGenDict:
2752 tensMeta["data_gen"] = tvgInfo.dataGenDict
2753 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002754
2755 result = build_fcn(
2756 self,
2757 op,
2758 tens,
2759 argsDict,
2760 validator_fcns=error_if_validators,
2761 error_name=error_name,
2762 qinfo=qinfo,
2763 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002764 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002765 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002766 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002767
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002768 try:
2769 if error_if_validators is None:
2770 if qinfo is not None:
2771 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2772 else:
2773 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002774 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002775 if qinfo is not None:
2776 result = build_fcn(
2777 self,
2778 op,
2779 *tens,
2780 *testArgs,
2781 validator_fcns=error_if_validators,
2782 error_name=error_name,
2783 qinfo=qinfo,
2784 )
2785 else:
2786 result = build_fcn(
2787 self,
2788 op,
2789 *tens,
2790 *testArgs,
2791 validator_fcns=error_if_validators,
2792 error_name=error_name,
2793 )
2794 except TypeError as e:
2795 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2796 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002797
Jeremy Johnson1271c442023-09-05 11:39:26 +01002798 if result:
Les Bell729b0352021-11-24 10:28:21 +00002799 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002800 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2801 # Add the compliance meta data
2802 # NOTE: This currently expects only one result output
2803 tensMeta["compliance"] = {
2804 "version": "0.1",
2805 "tensors": {result.resultTensor.name: result.complianceDict},
2806 }
2807 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002808 else:
2809 # The test is not valid
2810 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002811
Eric Kunzee5e26762020-10-13 16:11:07 -07002812 def createDynamicOpLists(self):
2813
Jeremy Johnson00423432022-09-12 17:27:37 +01002814 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2815 # Already created these lists (can occur when class is initialized more than once)
2816 return
2817
Eric Kunzee5e26762020-10-13 16:11:07 -07002818 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002819 if not self.args.level8k:
2820 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2821 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2822 else:
2823 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2824 KERNELS_2D = [[1, bigK], [bigK, 2]]
2825 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002826
Kevin Cheng1533b852021-09-01 12:51:58 -07002827 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002828 testName = "conv2d_{}x{}".format(k[0], k[1])
2829 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2830 self.TOSA_OP_LIST[testName]["filter"] = k
2831 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002832
Kevin Cheng550ccc52021-03-03 11:21:43 -08002833 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2834 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2835 "depthwise_conv2d_TEMPLATE"
2836 ].copy()
2837 self.TOSA_OP_LIST[testName]["filter"] = k
2838 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002839
Kevin Cheng550ccc52021-03-03 11:21:43 -08002840 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2841 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2842 "transpose_conv2d_TEMPLATE"
2843 ].copy()
2844 self.TOSA_OP_LIST[testName]["filter"] = k
2845 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002846
Kevin Cheng1533b852021-09-01 12:51:58 -07002847 for k in KERNELS_3D:
2848 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2849 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2850 self.TOSA_OP_LIST[testName]["filter"] = k
2851 self.TOSA_OP_LIST[testName]["template"] = False
2852
Eric Kunzee5e26762020-10-13 16:11:07 -07002853 # Delete any templates after having created any dynamic ops
2854 # This is a two-pass operation because it's bad practice to delete
2855 # keys from dictionaries while iterating
2856 keyList = []
2857 for k in self.TOSA_OP_LIST:
2858 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002860 keyList.append(k)
2861 continue
2862 except KeyError:
2863 pass
2864
2865 for k in keyList:
2866 del self.TOSA_OP_LIST[k]
2867
2868 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002869 """Fill in default fields for ops if they aren't already specified.
2870 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002871 for op in self.TOSA_OP_LIST:
2872
2873 # Required fields
2874 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002875 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002876 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002877 raise Exception(
2878 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2879 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002880
2881 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002882 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002883 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002884 raise Exception(
2885 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2886 op
2887 )
2888 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002889
2890 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002891 _ = self.TOSA_OP_LIST[op]["types"]
2892 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002893 raise Exception(
2894 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2895 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002896
2897 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002898 _ = self.TOSA_OP_LIST[op]["op"]
2899 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002900 raise Exception(
2901 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2902 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002903
2904 # Put in default rank range, if missing
2905 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002906 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002907 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002908 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002909
    # Tensor operator list (TOSA_OP_LIST): each key is the test op name and
    # each value is a dictionary with the required fields:
    # 'op': the Op enum value to build (e.g. Op.ARGMAX)
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #    if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
    #    i.e. (1, TOSA_TENSOR_MAX_RANK)
    # 'build_fcn': 4-tuple of (build_operator(), TensorGen function,
    #    TensorValuesGen function, ArgGen function)
    # 'types': array of datatypes to be tested
    # Optional fields read by this generator:
    # 'qgen': quantization info generator, called as
    #    qgen(self, op, dtype_or_dtypeList, error_name)
    # 'error_if_validators': validators passed to the build function as
    #    validator_fcns
    # 'invalid_test_validators': positive tests for which any of these
    #    validators returns True are dropped from the test list
    # 'template' / 'filter': templated ops expanded per kernel size by
    #    createDynamicOpLists
    # 'data_gen' / 'compliance': per-op data generation and compliance
    #    settings - NOTE(review): their consumers are not visible here, confirm
    # Floating-point types
    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.INT32,
    ]  # floating-types and INT32
    # Floating-point, integer (excluding INT4) and boolean types
    TYPE_FIB = [
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.BOOL,
    ]
    TYPE_FI16 = [DType.FP32, DType.INT16]

    # INT8/INT16 plus the floating-point types (no INT32)
    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

    # List of [Input Type 1, Input Type 2, Accumulator Type]
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        [DType.FP16, DType.FP16, DType.FP16],
        [DType.FP16, DType.FP16, DType.FP32],
        [DType.BF16, DType.BF16, DType.FP32],
        [DType.FP32, DType.FP32, DType.FP32],
    ]

    # Inclusive (min, max) rank used by initOpListDefaults for ops without 'rank'
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002961
2962 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002964 "argmax": {
2965 "op": Op.ARGMAX,
2966 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002967 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002968 "build_fcn": (
2969 build_argmax,
2970 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002971 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002972 TosaArgGen.agAxis,
2973 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002974 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002975 "error_if_validators": (
2976 TosaErrorValidator.evAxisSmallerZero,
2977 TosaErrorValidator.evAxisLargerRank,
2978 TosaErrorValidator.evArgmaxOutputRankMismatch,
2979 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2980 TosaErrorValidator.evWrongRank,
2981 TosaErrorValidator.evWrongInputType,
2982 TosaErrorValidator.evWrongOutputType,
2983 TosaErrorValidator.evWrongInputList,
2984 TosaErrorValidator.evWrongOutputList,
2985 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002986 "data_gen": {
2987 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2988 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002990 "avg_pool2d": {
2991 "op": Op.AVG_POOL2D,
2992 "operands": (1, 0),
2993 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002994 "build_fcn": (
2995 build_pool2d,
2996 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002997 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002998 TosaArgGen.agPooling,
2999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003000 "qgen": TosaQuantGen.qgUnary,
3001 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003002 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003003 "error_if_validators": (
3004 TosaErrorValidator.evKernelSmallerOne,
3005 TosaErrorValidator.evStrideSmallerOne,
3006 TosaErrorValidator.evPadSmallerZero,
3007 TosaErrorValidator.evWrongRank,
3008 TosaErrorValidator.evWrongInputType,
3009 TosaErrorValidator.evWrongOutputType,
3010 TosaErrorValidator.evWrongInputList,
3011 TosaErrorValidator.evWrongOutputList,
3012 TosaErrorValidator.evInputZeroPointNotZero,
3013 TosaErrorValidator.evOutputZeroPointNotZero,
3014 TosaErrorValidator.evPadLargerEqualKernel,
3015 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003016 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003017 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003018 "data_gen": {
3019 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3020 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003021 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003022 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003023 "conv2d_TEMPLATE": {
3024 "op": Op.CONV2D,
3025 "operands": (1, 2),
3026 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003027 "build_fcn": (
3028 build_conv2d,
3029 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003030 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003031 TosaArgGen.agConv,
3032 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003033 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003034 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003035 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3036 "error_if_validators": (
3037 TosaErrorValidator.evWrongInputType,
3038 TosaErrorValidator.evWrongOutputType,
3039 TosaErrorValidator.evWrongInputList,
3040 TosaErrorValidator.evWrongOutputList,
3041 TosaErrorValidator.evInputZeroPointNotZero,
3042 TosaErrorValidator.evWeightZeroPointNotZero,
3043 TosaErrorValidator.evPadSmallerZero,
3044 TosaErrorValidator.evStrideSmallerOne,
3045 TosaErrorValidator.evDilationSmallerOne,
3046 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003047 TosaErrorValidator.evConvOutputShapeMismatch,
3048 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003049 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003050 "data_gen": {
3051 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3052 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003053 "template": True,
3054 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003055 # Templated operator. Filled in by createDynamicOpLists
3056 "conv3d_TEMPLATE": {
3057 "op": Op.CONV3D,
3058 "operands": (1, 2),
3059 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003060 "build_fcn": (
3061 build_conv3d,
3062 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003063 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003064 TosaArgGen.agConv,
3065 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003066 "qgen": TosaQuantGen.qgConv,
3067 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003068 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3069 "error_if_validators": (
3070 TosaErrorValidator.evWrongInputType,
3071 TosaErrorValidator.evWrongOutputType,
3072 TosaErrorValidator.evWrongInputList,
3073 TosaErrorValidator.evWrongOutputList,
3074 TosaErrorValidator.evInputZeroPointNotZero,
3075 TosaErrorValidator.evWeightZeroPointNotZero,
3076 TosaErrorValidator.evPadSmallerZero,
3077 TosaErrorValidator.evStrideSmallerOne,
3078 TosaErrorValidator.evDilationSmallerOne,
3079 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003080 TosaErrorValidator.evConvOutputShapeMismatch,
3081 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003082 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003083 "template": True,
3084 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003085 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003086 "depthwise_conv2d_TEMPLATE": {
3087 "op": Op.DEPTHWISE_CONV2D,
3088 "operands": (1, 2),
3089 "filter": [1, 1],
3090 "rank": (4, 4),
3091 "build_fcn": (
3092 build_depthwise_conv2d,
3093 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003094 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003095 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003096 ),
3097 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003098 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003099 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3100 "error_if_validators": (
3101 TosaErrorValidator.evWrongInputType,
3102 TosaErrorValidator.evWrongOutputType,
3103 TosaErrorValidator.evWrongInputList,
3104 TosaErrorValidator.evWrongOutputList,
3105 TosaErrorValidator.evInputZeroPointNotZero,
3106 TosaErrorValidator.evWeightZeroPointNotZero,
3107 TosaErrorValidator.evPadSmallerZero,
3108 TosaErrorValidator.evStrideSmallerOne,
3109 TosaErrorValidator.evDilationSmallerOne,
3110 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003111 TosaErrorValidator.evConvOutputShapeMismatch,
3112 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003113 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003114 "template": True,
3115 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 "fully_connected": {
3117 "op": Op.FULLY_CONNECTED,
3118 "operands": (1, 2),
3119 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 "build_fcn": (
3121 build_fully_connected,
3122 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003123 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003124 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003127 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003128 "error_if_validators": (
3129 TosaErrorValidator.evInputZeroPointNotZero,
3130 TosaErrorValidator.evWeightZeroPointNotZero,
3131 TosaErrorValidator.evWrongRank,
3132 TosaErrorValidator.evWrongInputType,
3133 TosaErrorValidator.evWrongOutputType,
3134 TosaErrorValidator.evWrongInputList,
3135 TosaErrorValidator.evWrongOutputList,
3136 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003137 "data_gen": {
3138 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3139 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "matmul": {
3142 "op": Op.MATMUL,
3143 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003144 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003145 "build_fcn": (
3146 build_matmul,
3147 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003148 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003149 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003150 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "qgen": TosaQuantGen.qgMatmul,
3152 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 "error_if_validators": (
3154 TosaErrorValidator.evInputZeroPointNotZero,
3155 TosaErrorValidator.evWrongRank,
3156 TosaErrorValidator.evWrongInputType,
3157 TosaErrorValidator.evWrongOutputType,
3158 TosaErrorValidator.evWrongInputList,
3159 TosaErrorValidator.evWrongOutputList,
3160 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003161 "data_gen": {
3162 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 "max_pool2d": {
3166 "op": Op.MAX_POOL2D,
3167 "operands": (1, 0),
3168 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003170 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003171 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003172 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003173 TosaArgGen.agPooling,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003176 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 "error_if_validators": (
3178 TosaErrorValidator.evKernelSmallerOne,
3179 TosaErrorValidator.evStrideSmallerOne,
3180 TosaErrorValidator.evPadSmallerZero,
3181 TosaErrorValidator.evWrongRank,
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 TosaErrorValidator.evPadLargerEqualKernel,
3187 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003188 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003189 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003190 "data_gen": {
3191 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003194 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003195 "transpose_conv2d_TEMPLATE": {
3196 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003197 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003198 "rank": (4, 4),
3199 "build_fcn": (
3200 build_transpose_conv2d,
3201 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003202 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003203 TosaArgGen.agTransposeConv2D,
3204 ),
3205 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003206 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003207 "invalid_test_validators": (
3208 TosaInvalidValidator.ivHeightWidthInvalid,
3209 TosaInvalidValidator.ivNonPositiveOutputShape,
3210 ),
3211 "error_if_validators": (
3212 TosaErrorValidator.evWrongInputType,
3213 TosaErrorValidator.evWrongOutputType,
3214 TosaErrorValidator.evWrongInputList,
3215 TosaErrorValidator.evWrongOutputList,
3216 TosaErrorValidator.evInputZeroPointNotZero,
3217 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003218 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003219 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003220 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003221 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003222 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003223 "template": True,
3224 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003225 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 "clamp": {
3227 "op": Op.CLAMP,
3228 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003229 "build_fcn": (
3230 build_clamp,
3231 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003232 TosaTensorValuesGen.tvgLazyGenDefault,
3233 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003235 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evMaxSmallerMin,
3238 TosaErrorValidator.evWrongInputType,
3239 TosaErrorValidator.evWrongOutputType,
3240 TosaErrorValidator.evWrongInputList,
3241 TosaErrorValidator.evWrongOutputList,
3242 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003243 "data_gen": {
3244 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3245 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003246 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003247 "sigmoid": {
3248 "op": Op.SIGMOID,
3249 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003251 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003253 TosaTensorValuesGen.tvgLazyGenDefault,
3254 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003256 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003257 "error_if_validators": (
3258 TosaErrorValidator.evWrongInputType,
3259 TosaErrorValidator.evWrongOutputType,
3260 TosaErrorValidator.evWrongInputList,
3261 TosaErrorValidator.evWrongOutputList,
3262 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003263 "data_gen": {
3264 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3265 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003266 },
3267 "tanh": {
3268 "op": Op.TANH,
3269 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003270 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003271 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003273 TosaTensorValuesGen.tvgLazyGenDefault,
3274 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003275 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003276 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003277 "error_if_validators": (
3278 TosaErrorValidator.evWrongInputType,
3279 TosaErrorValidator.evWrongOutputType,
3280 TosaErrorValidator.evWrongInputList,
3281 TosaErrorValidator.evWrongOutputList,
3282 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003283 "data_gen": {
3284 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3285 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003286 "compliance": {
3287 "abs_error_lower_bound": 0.5,
3288 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003289 },
Won Jeon78155c62023-06-10 00:20:04 +00003290 "erf": {
3291 "op": Op.ERF,
3292 "operands": (1, 0),
3293 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003294 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003295 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003296 TosaTensorValuesGen.tvgLazyGenDefault,
3297 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003298 ),
3299 "types": TYPE_FP,
3300 "error_if_validators": (
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003306 "data_gen": {
3307 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3308 },
3309 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 # Elementwise Binary Operators
3312 "add": {
3313 "op": Op.ADD,
3314 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 "build_fcn": (
3316 build_binary_broadcast,
3317 TosaTensorGen.tgBroadcastFuzz,
3318 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003319 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 "error_if_validators": (
3323 TosaErrorValidator.evRankMismatch,
3324 TosaErrorValidator.evWrongInputType,
3325 TosaErrorValidator.evWrongOutputType,
3326 TosaErrorValidator.evWrongInputList,
3327 TosaErrorValidator.evWrongOutputList,
3328 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003329 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003331 "data_gen": {
3332 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3333 },
3334 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 "arithmetic_right_shift": {
3337 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3338 "operands": (2, 0),
3339 "build_fcn": (
3340 build_arithmetic_right_shift,
3341 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 TosaArgGen.agArithmeticRightShift,
3344 ),
3345 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003346 "error_if_validators": (
3347 TosaErrorValidator.evRankMismatch,
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003353 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003354 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 "bitwise_and": {
3357 "op": Op.BITWISE_AND,
3358 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 "build_fcn": (
3360 build_binary_broadcast,
3361 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003362 TosaTensorValuesGen.tvgLazyGenDefault,
3363 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003364 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003366 "error_if_validators": (
3367 TosaErrorValidator.evRankMismatch,
3368 TosaErrorValidator.evWrongInputType,
3369 TosaErrorValidator.evWrongOutputType,
3370 TosaErrorValidator.evWrongInputList,
3371 TosaErrorValidator.evWrongOutputList,
3372 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003373 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003374 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 "bitwise_or": {
3377 "op": Op.BITWISE_OR,
3378 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003379 "build_fcn": (
3380 build_binary_broadcast,
3381 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003382 TosaTensorValuesGen.tvgLazyGenDefault,
3383 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003384 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003386 "error_if_validators": (
3387 TosaErrorValidator.evRankMismatch,
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003393 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003394 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 "bitwise_xor": {
3397 "op": Op.BITWISE_XOR,
3398 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 "build_fcn": (
3400 build_binary_broadcast,
3401 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003402 TosaTensorValuesGen.tvgLazyGenDefault,
3403 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003406 "error_if_validators": (
3407 TosaErrorValidator.evRankMismatch,
3408 TosaErrorValidator.evWrongInputType,
3409 TosaErrorValidator.evWrongOutputType,
3410 TosaErrorValidator.evWrongInputList,
3411 TosaErrorValidator.evWrongOutputList,
3412 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003413 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003414 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003416 "intdiv": {
3417 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003418 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 "build_fcn": (
3420 build_binary_broadcast,
3421 TosaTensorGen.tgBroadcastFuzz,
3422 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003423 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003424 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003425 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003426 "error_if_validators": (
3427 TosaErrorValidator.evRankMismatch,
3428 TosaErrorValidator.evWrongInputType,
3429 TosaErrorValidator.evWrongOutputType,
3430 TosaErrorValidator.evWrongInputList,
3431 TosaErrorValidator.evWrongOutputList,
3432 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003433 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003434 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003435 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003436 "logical_and": {
3437 "op": Op.LOGICAL_AND,
3438 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_binary_broadcast,
3441 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003442 TosaTensorValuesGen.tvgLazyGenDefault,
3443 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003446 "error_if_validators": (
3447 TosaErrorValidator.evRankMismatch,
3448 TosaErrorValidator.evWrongInputType,
3449 TosaErrorValidator.evWrongOutputType,
3450 TosaErrorValidator.evWrongInputList,
3451 TosaErrorValidator.evWrongOutputList,
3452 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003453 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003454 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 "logical_left_shift": {
3457 "op": Op.LOGICAL_LEFT_SHIFT,
3458 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
3460 build_binary_broadcast,
3461 TosaTensorGen.tgBroadcastFuzz,
3462 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003463 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003466 "error_if_validators": (
3467 TosaErrorValidator.evRankMismatch,
3468 TosaErrorValidator.evWrongInputType,
3469 TosaErrorValidator.evWrongOutputType,
3470 TosaErrorValidator.evWrongInputList,
3471 TosaErrorValidator.evWrongOutputList,
3472 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003473 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003476 "logical_right_shift": {
3477 "op": Op.LOGICAL_RIGHT_SHIFT,
3478 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003479 "build_fcn": (
3480 build_binary_broadcast,
3481 TosaTensorGen.tgBroadcastFuzz,
3482 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003483 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003484 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003486 "error_if_validators": (
3487 TosaErrorValidator.evRankMismatch,
3488 TosaErrorValidator.evWrongInputType,
3489 TosaErrorValidator.evWrongOutputType,
3490 TosaErrorValidator.evWrongInputList,
3491 TosaErrorValidator.evWrongOutputList,
3492 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003493 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003494 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "logical_or": {
3497 "op": Op.LOGICAL_OR,
3498 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003499 "build_fcn": (
3500 build_binary_broadcast,
3501 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003502 TosaTensorValuesGen.tvgLazyGenDefault,
3503 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003504 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003506 "error_if_validators": (
3507 TosaErrorValidator.evRankMismatch,
3508 TosaErrorValidator.evWrongInputType,
3509 TosaErrorValidator.evWrongOutputType,
3510 TosaErrorValidator.evWrongInputList,
3511 TosaErrorValidator.evWrongOutputList,
3512 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003513 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003516 "logical_xor": {
3517 "op": Op.LOGICAL_XOR,
3518 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 "build_fcn": (
3520 build_binary_broadcast,
3521 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003522 TosaTensorValuesGen.tvgLazyGenDefault,
3523 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evRankMismatch,
3528 TosaErrorValidator.evWrongInputType,
3529 TosaErrorValidator.evWrongOutputType,
3530 TosaErrorValidator.evWrongInputList,
3531 TosaErrorValidator.evWrongOutputList,
3532 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003533 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "maximum": {
3537 "op": Op.MAXIMUM,
3538 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_binary_broadcast,
3541 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003542 TosaTensorValuesGen.tvgLazyGenDefault,
3543 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evRankMismatch,
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003553 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003555 "data_gen": {
3556 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3557 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "minimum": {
3560 "op": Op.MINIMUM,
3561 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 "build_fcn": (
3563 build_binary_broadcast,
3564 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003565 TosaTensorValuesGen.tvgLazyGenDefault,
3566 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 "error_if_validators": (
3570 TosaErrorValidator.evRankMismatch,
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003576 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003577 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003578 "data_gen": {
3579 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 "mul": {
3583 "op": Op.MUL,
3584 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 "build_fcn": (
3586 build_mul,
3587 TosaTensorGen.tgBroadcastFuzz,
3588 TosaTensorValuesGen.tvgMul,
3589 TosaArgGen.agMul,
3590 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 "error_if_validators": (
3593 TosaErrorValidator.evWrongInputType,
3594 TosaErrorValidator.evWrongOutputType,
3595 TosaErrorValidator.evWrongInputList,
3596 TosaErrorValidator.evWrongOutputList,
3597 TosaErrorValidator.evRankMismatch,
3598 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003599 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003600 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003601 "data_gen": {
3602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3603 },
3604 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 "pow": {
3607 "op": Op.POW,
3608 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003609 "build_fcn": (
3610 build_binary_broadcast,
3611 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003612 TosaTensorValuesGen.tvgPow,
3613 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 "error_if_validators": (
3617 TosaErrorValidator.evRankMismatch,
3618 TosaErrorValidator.evWrongInputType,
3619 TosaErrorValidator.evWrongOutputType,
3620 TosaErrorValidator.evWrongInputList,
3621 TosaErrorValidator.evWrongOutputList,
3622 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003623 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003625 "data_gen": {
3626 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3627 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "sub": {
3630 "op": Op.SUB,
3631 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
3633 build_binary_broadcast,
3634 TosaTensorGen.tgBroadcastFuzz,
3635 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003636 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evRankMismatch,
3641 TosaErrorValidator.evWrongInputType,
3642 TosaErrorValidator.evWrongOutputType,
3643 TosaErrorValidator.evWrongInputList,
3644 TosaErrorValidator.evWrongOutputList,
3645 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003646 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003648 "data_gen": {
3649 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3650 },
3651 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "table": {
3654 "op": Op.TABLE,
3655 # Use the automatic generation functions to create the input array
3656 # but create the table tensor in the build function, as it may be
3657 # a different type from the input
3658 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 "build_fcn": (
3660 build_table,
3661 TosaTensorGen.tgBasic,
3662 TosaTensorValuesGen.tvgDefault,
3663 TosaArgGen.agTable,
3664 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003665 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003666 "error_if_validators": (
3667 TosaErrorValidator.evWrongInputType,
3668 TosaErrorValidator.evWrongOutputType,
3669 TosaErrorValidator.evWrongInputList,
3670 TosaErrorValidator.evWrongOutputList,
3671 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003673 # Elementwise Unary operators
3674 "abs": {
3675 "op": Op.ABS,
3676 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 "build_fcn": (
3678 build_unary,
3679 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003680 TosaTensorValuesGen.tvgLazyGenDefault,
3681 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003683 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003684 "error_if_validators": (
3685 TosaErrorValidator.evWrongInputType,
3686 TosaErrorValidator.evWrongOutputType,
3687 TosaErrorValidator.evWrongInputList,
3688 TosaErrorValidator.evWrongOutputList,
3689 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003690 "data_gen": {
3691 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 "bitwise_not": {
3695 "op": Op.BITWISE_NOT,
3696 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 "build_fcn": (
3698 build_unary,
3699 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003700 TosaTensorValuesGen.tvgLazyGenDefault,
3701 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003704 "error_if_validators": (
3705 TosaErrorValidator.evWrongInputType,
3706 TosaErrorValidator.evWrongOutputType,
3707 TosaErrorValidator.evWrongInputList,
3708 TosaErrorValidator.evWrongOutputList,
3709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "ceil": {
3712 "op": Op.CEIL,
3713 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_unary,
3716 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003717 TosaTensorValuesGen.tvgLazyGenDefault,
3718 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongInputList,
3725 TosaErrorValidator.evWrongOutputList,
3726 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003727 "data_gen": {
3728 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3729 },
3730 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003732 "clz": {
3733 "op": Op.CLZ,
3734 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_unary,
3737 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003738 TosaTensorValuesGen.tvgLazyGenDefault,
3739 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
3747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "exp": {
3750 "op": Op.EXP,
3751 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_unary,
3754 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003755 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
3764 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003765 "data_gen": {
3766 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3767 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 "floor": {
3770 "op": Op.FLOOR,
3771 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 "build_fcn": (
3773 build_unary,
3774 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003775 TosaTensorValuesGen.tvgLazyGenDefault,
3776 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 "error_if_validators": (
3780 TosaErrorValidator.evWrongInputType,
3781 TosaErrorValidator.evWrongOutputType,
3782 TosaErrorValidator.evWrongInputList,
3783 TosaErrorValidator.evWrongOutputList,
3784 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003785 "data_gen": {
3786 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3787 },
3788 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "log": {
3791 "op": Op.LOG,
3792 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003793 "build_fcn": (
3794 build_unary,
3795 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003796 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003797 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003798 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003800 "error_if_validators": (
3801 TosaErrorValidator.evWrongInputType,
3802 TosaErrorValidator.evWrongOutputType,
3803 TosaErrorValidator.evWrongInputList,
3804 TosaErrorValidator.evWrongOutputList,
3805 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003806 "data_gen": {
3807 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3808 },
3809 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003810 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003811 "logical_not": {
3812 "op": Op.LOGICAL_NOT,
3813 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 "build_fcn": (
3815 build_unary,
3816 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003817 TosaTensorValuesGen.tvgLazyGenDefault,
3818 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003821 "error_if_validators": (
3822 TosaErrorValidator.evWrongInputType,
3823 TosaErrorValidator.evWrongOutputType,
3824 TosaErrorValidator.evWrongInputList,
3825 TosaErrorValidator.evWrongOutputList,
3826 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "negate": {
3829 "op": Op.NEGATE,
3830 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003831 "build_fcn": (
3832 build_unary,
3833 TosaTensorGen.tgBasic,
3834 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003835 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 "qgen": TosaQuantGen.qgUnary,
3838 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 "error_if_validators": (
3840 TosaErrorValidator.evInputZeroPointNotZero,
3841 TosaErrorValidator.evOutputZeroPointNotZero,
3842 TosaErrorValidator.evWrongInputType,
3843 TosaErrorValidator.evWrongOutputType,
3844 TosaErrorValidator.evWrongInputList,
3845 TosaErrorValidator.evWrongOutputList,
3846 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003847 "data_gen": {
3848 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3849 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 "reciprocal": {
3852 "op": Op.RECIPROCAL,
3853 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003854 "build_fcn": (
3855 build_unary,
3856 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003857 TosaTensorValuesGen.tvgLazyGenDefault,
3858 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003861 "error_if_validators": (
3862 TosaErrorValidator.evWrongInputType,
3863 TosaErrorValidator.evWrongOutputType,
3864 TosaErrorValidator.evWrongInputList,
3865 TosaErrorValidator.evWrongOutputList,
3866 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003867 "data_gen": {
3868 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3869 },
3870 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003871 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 "rsqrt": {
3873 "op": Op.RSQRT,
3874 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003875 "build_fcn": (
3876 build_unary,
3877 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003878 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003879 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 "error_if_validators": (
3883 TosaErrorValidator.evWrongInputType,
3884 TosaErrorValidator.evWrongOutputType,
3885 TosaErrorValidator.evWrongInputList,
3886 TosaErrorValidator.evWrongOutputList,
3887 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003888 "data_gen": {
3889 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3890 },
3891 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003892 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 # Elementwise Ternary operators
3894 "select": {
3895 "op": Op.SELECT,
3896 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 "build_fcn": (
3898 build_select,
3899 TosaTensorGen.tgBroadcastFuzz,
3900 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003901 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 "error_if_validators": (
3905 TosaErrorValidator.evRankMismatch,
3906 TosaErrorValidator.evWrongInputType,
3907 TosaErrorValidator.evWrongOutputType,
3908 TosaErrorValidator.evWrongInputList,
3909 TosaErrorValidator.evWrongOutputList,
3910 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003911 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003912 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003913 "data_gen": {
3914 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3915 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003916 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003917 # Comparison operators
3918 "equal": {
3919 "op": Op.EQUAL,
3920 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003921 "build_fcn": (
3922 build_comparison,
3923 TosaTensorGen.tgBroadcastFuzz,
3924 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003925 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003926 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003927 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003928 "error_if_validators": (
3929 TosaErrorValidator.evRankMismatch,
3930 TosaErrorValidator.evWrongInputType,
3931 TosaErrorValidator.evWrongOutputType,
3932 TosaErrorValidator.evWrongInputList,
3933 TosaErrorValidator.evWrongOutputList,
3934 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003935 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003936 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003937 "data_gen": {
3938 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3939 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003940 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 "greater_equal": {
3942 "op": Op.GREATER_EQUAL,
3943 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003944 "build_fcn": (
3945 build_comparison,
3946 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003947 TosaTensorValuesGen.tvgLazyGenDefault,
3948 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003949 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003950 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 "error_if_validators": (
3952 TosaErrorValidator.evRankMismatch,
3953 TosaErrorValidator.evWrongInputType,
3954 TosaErrorValidator.evWrongOutputType,
3955 TosaErrorValidator.evWrongInputList,
3956 TosaErrorValidator.evWrongOutputList,
3957 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003958 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003960 "data_gen": {
3961 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 "greater": {
3965 "op": Op.GREATER,
3966 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 "build_fcn": (
3968 build_comparison,
3969 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003970 TosaTensorValuesGen.tvgLazyGenDefault,
3971 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003973 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003974 "error_if_validators": (
3975 TosaErrorValidator.evRankMismatch,
3976 TosaErrorValidator.evWrongInputType,
3977 TosaErrorValidator.evWrongOutputType,
3978 TosaErrorValidator.evWrongInputList,
3979 TosaErrorValidator.evWrongOutputList,
3980 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003981 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003982 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003983 "data_gen": {
3984 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003987 # Reduction operators
3988 "reduce_all": {
3989 "op": Op.REDUCE_ALL,
3990 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003991 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 "build_fcn": (
3993 build_reduce,
3994 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003995 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003996 TosaArgGen.agAxis,
3997 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003998 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003999 "error_if_validators": (
4000 TosaErrorValidator.evAxisLargerRank,
4001 TosaErrorValidator.evAxisSmallerZero,
4002 TosaErrorValidator.evShapeOfAxisNotOne,
4003 TosaErrorValidator.evWrongInputType,
4004 TosaErrorValidator.evWrongOutputType,
4005 TosaErrorValidator.evWrongRank,
4006 TosaErrorValidator.evWrongInputList,
4007 TosaErrorValidator.evWrongOutputList,
4008 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004009 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 "reduce_any": {
4011 "op": Op.REDUCE_ANY,
4012 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004013 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 "build_fcn": (
4015 build_reduce,
4016 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004017 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004018 TosaArgGen.agAxis,
4019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 "error_if_validators": (
4022 TosaErrorValidator.evAxisLargerRank,
4023 TosaErrorValidator.evAxisSmallerZero,
4024 TosaErrorValidator.evShapeOfAxisNotOne,
4025 TosaErrorValidator.evWrongInputType,
4026 TosaErrorValidator.evWrongOutputType,
4027 TosaErrorValidator.evWrongRank,
4028 TosaErrorValidator.evWrongInputList,
4029 TosaErrorValidator.evWrongOutputList,
4030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 "reduce_max": {
4033 "op": Op.REDUCE_MAX,
4034 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004035 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 "build_fcn": (
4037 build_reduce,
4038 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004039 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004040 TosaArgGen.agAxis,
4041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004042 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004043 "error_if_validators": (
4044 TosaErrorValidator.evAxisLargerRank,
4045 TosaErrorValidator.evAxisSmallerZero,
4046 TosaErrorValidator.evShapeOfAxisNotOne,
4047 TosaErrorValidator.evWrongInputType,
4048 TosaErrorValidator.evWrongOutputType,
4049 TosaErrorValidator.evWrongRank,
4050 TosaErrorValidator.evWrongInputList,
4051 TosaErrorValidator.evWrongOutputList,
4052 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004053 "data_gen": {
4054 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004057 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004058 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004060 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004061 "build_fcn": (
4062 build_reduce,
4063 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004064 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004065 TosaArgGen.agAxis,
4066 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004067 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004068 "error_if_validators": (
4069 TosaErrorValidator.evAxisLargerRank,
4070 TosaErrorValidator.evAxisSmallerZero,
4071 TosaErrorValidator.evShapeOfAxisNotOne,
4072 TosaErrorValidator.evWrongInputType,
4073 TosaErrorValidator.evWrongOutputType,
4074 TosaErrorValidator.evWrongRank,
4075 TosaErrorValidator.evWrongInputList,
4076 TosaErrorValidator.evWrongOutputList,
4077 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004078 "data_gen": {
4079 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4080 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "reduce_product": {
4083 "op": Op.REDUCE_PRODUCT,
4084 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004085 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_reduce,
4088 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004089 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 TosaArgGen.agAxis,
4091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evAxisLargerRank,
4095 TosaErrorValidator.evAxisSmallerZero,
4096 TosaErrorValidator.evShapeOfAxisNotOne,
4097 TosaErrorValidator.evWrongInputType,
4098 TosaErrorValidator.evWrongOutputType,
4099 TosaErrorValidator.evWrongRank,
4100 TosaErrorValidator.evWrongInputList,
4101 TosaErrorValidator.evWrongOutputList,
4102 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004103 "data_gen": {
4104 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4105 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004106 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 "reduce_sum": {
4108 "op": Op.REDUCE_SUM,
4109 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004110 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 "build_fcn": (
4112 build_reduce,
4113 TosaTensorGen.tgBasic,
4114 TosaTensorValuesGen.tvgReduceSum,
4115 TosaArgGen.agAxis,
4116 ),
James Ward24dbc422022-10-19 12:20:31 +01004117 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 "error_if_validators": (
4119 TosaErrorValidator.evAxisLargerRank,
4120 TosaErrorValidator.evAxisSmallerZero,
4121 TosaErrorValidator.evShapeOfAxisNotOne,
4122 TosaErrorValidator.evWrongInputType,
4123 TosaErrorValidator.evWrongOutputType,
4124 TosaErrorValidator.evWrongRank,
4125 TosaErrorValidator.evWrongInputList,
4126 TosaErrorValidator.evWrongOutputList,
4127 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004128 "data_gen": {
4129 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4130 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004131 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004132 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004133 "concat": {
4134 "op": Op.CONCAT,
4135 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004136 "build_fcn": (
4137 build_concat,
4138 TosaTensorGen.tgConcat,
4139 TosaTensorValuesGen.tvgConcat,
4140 TosaArgGen.agAxis,
4141 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004142 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004143 "error_if_validators": (
4144 TosaErrorValidator.evAxisLargerRank,
4145 TosaErrorValidator.evAxisSmallerZero,
4146 TosaErrorValidator.evConcatInputRankMismatch,
4147 TosaErrorValidator.evConcatShapeSumMismatch,
4148 TosaErrorValidator.evConcatInputDimMismatch,
4149 TosaErrorValidator.evWrongInputType,
4150 TosaErrorValidator.evWrongOutputType,
4151 TosaErrorValidator.evWrongOutputList,
4152 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004153 "data_gen": {
4154 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4155 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004156 },
4157 "pad": {
4158 "op": Op.PAD,
4159 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 "build_fcn": (
4161 build_pad,
4162 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004163 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004164 TosaArgGen.agPad,
4165 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004166 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004167 "error_if_validators": (
4168 TosaErrorValidator.evWrongInputType,
4169 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004170 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004171 TosaErrorValidator.evWrongOutputType,
4172 TosaErrorValidator.evWrongInputList,
4173 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004174 TosaErrorValidator.evRankMismatch,
4175 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004176 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004177 "data_gen": {
4178 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4179 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004180 },
Won Jeona21b2e82023-08-10 10:33:01 +00004181 "dim": {
4182 "op": Op.DIM,
4183 "operands": (1, 0),
4184 "build_fcn": (
4185 build_dim,
4186 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004187 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004188 TosaArgGen.agAxis,
4189 ),
4190 "types": TYPE_FIB,
4191 "error_if_validators": (
4192 TosaErrorValidator.evAxisLargerRank,
4193 TosaErrorValidator.evAxisSmallerZero,
4194 TosaErrorValidator.evWrongInputType,
4195 TosaErrorValidator.evWrongInputList,
4196 TosaErrorValidator.evWrongOutputList,
4197 TosaErrorValidator.evWrongRank,
4198 ),
4199 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004200 "reshape": {
4201 "op": Op.RESHAPE,
4202 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 "build_fcn": (
4204 build_reshape,
4205 TosaTensorGen.tgBasic,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004206 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004207 TosaArgGen.agReshape,
4208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004209 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004210 "error_if_validators": (
4211 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4212 TosaErrorValidator.evWrongInputType,
4213 TosaErrorValidator.evWrongOutputType,
4214 TosaErrorValidator.evWrongInputList,
4215 TosaErrorValidator.evWrongOutputList,
4216 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004217 "data_gen": {
4218 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4219 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004220 },
4221 "reverse": {
4222 "op": Op.REVERSE,
4223 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004224 "build_fcn": (
4225 build_reverse,
4226 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004227 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 TosaArgGen.agAxis,
4229 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004230 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004231 "error_if_validators": (
4232 TosaErrorValidator.evAxisSmallerZero,
4233 TosaErrorValidator.evAxisLargerRank,
4234 TosaErrorValidator.evWrongInputType,
4235 TosaErrorValidator.evWrongOutputType,
4236 TosaErrorValidator.evWrongInputList,
4237 TosaErrorValidator.evWrongOutputList,
4238 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004239 },
4240 "slice": {
4241 "op": Op.SLICE,
4242 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004243 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 "build_fcn": (
4245 build_slice,
4246 TosaTensorGen.tgBasic,
4247 TosaTensorValuesGen.tvgDefault,
4248 TosaArgGen.agSlice,
4249 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004250 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004251 "error_if_validators": (
4252 TosaErrorValidator.evStartSmallerZero,
4253 TosaErrorValidator.evSizeSmallerEqualZero,
4254 TosaErrorValidator.evStartSizeOutsideBounds,
4255 TosaErrorValidator.evSizeOutputShapeMismatch,
4256 TosaErrorValidator.evInputSizeStartLengthMismatch,
4257 TosaErrorValidator.evWrongRank,
4258 TosaErrorValidator.evWrongInputType,
4259 TosaErrorValidator.evWrongOutputType,
4260 TosaErrorValidator.evWrongInputList,
4261 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004262 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004263 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004264 },
4265 "tile": {
4266 "op": Op.TILE,
4267 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004268 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004269 "build_fcn": (
4270 build_tile,
4271 TosaTensorGen.tgBasic,
4272 TosaTensorValuesGen.tvgDefault,
4273 TosaArgGen.agTile,
4274 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004275 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004276 "error_if_validators": (
4277 TosaErrorValidator.evWrongInputType,
4278 TosaErrorValidator.evWrongOutputType,
4279 TosaErrorValidator.evWrongInputList,
4280 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004281 TosaErrorValidator.evRankMismatch,
4282 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004283 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004284 },
4285 "transpose": {
4286 "op": Op.TRANSPOSE,
4287 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004288 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004289 "build_fcn": (
4290 build_transpose,
4291 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004292 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004293 TosaArgGen.agTranspose,
4294 ),
4295 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004296 "error_if_validators": (
4297 TosaErrorValidator.evIndexOutsideBounds,
4298 TosaErrorValidator.evIndexUsedTwice,
4299 TosaErrorValidator.evWrongInputType,
4300 TosaErrorValidator.evWrongOutputType,
4301 TosaErrorValidator.evWrongInputList,
4302 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004303 TosaErrorValidator.evWrongRank,
4304 TosaErrorValidator.evRankMismatch,
4305 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004306 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004308 # Data nodes
4309 "const": {
4310 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004311 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004312 "build_fcn": (
4313 build_const,
4314 TosaTensorGen.tgBasic,
4315 TosaTensorValuesGen.tvgDefault,
4316 None,
4317 ),
Luke Hutton65872422023-02-20 10:33:04 +00004318 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004320 "identity": {
4321 "op": Op.IDENTITY,
4322 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004323 "build_fcn": (
4324 build_unary,
4325 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004326 TosaTensorValuesGen.tvgLazyGenDefault,
4327 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004329 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004330 "data_gen": {
4331 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004333 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004334 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004335 "gather": {
4336 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004337 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004338 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 "build_fcn": (
4340 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004341 TosaTensorGen.tgGather,
4342 TosaTensorValuesGen.tvgGather,
4343 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004344 ),
James Ward24dbc422022-10-19 12:20:31 +01004345 "types": (
4346 DType.INT8,
4347 DType.INT16,
4348 DType.INT32,
4349 DType.FP16,
4350 DType.BF16,
4351 DType.FP32,
4352 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004353 "error_if_validators": (
4354 TosaErrorValidator.evWrongInputType,
4355 TosaErrorValidator.evWrongOutputType,
4356 TosaErrorValidator.evWrongInputList,
4357 TosaErrorValidator.evWrongOutputList,
4358 TosaErrorValidator.evWrongRank,
4359 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004360 "data_gen": {
4361 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4362 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004363 },
4364 "scatter": {
4365 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004366 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004367 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004368 "build_fcn": (
4369 build_scatter,
4370 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004371 TosaTensorValuesGen.tvgScatter,
4372 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004373 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004374 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004375 "error_if_validators": (
4376 TosaErrorValidator.evWrongInputType,
4377 TosaErrorValidator.evWrongOutputType,
4378 TosaErrorValidator.evWrongInputList,
4379 TosaErrorValidator.evWrongOutputList,
4380 TosaErrorValidator.evWrongRank,
4381 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004382 "data_gen": {
4383 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4384 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004385 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004386 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004387 "resize": {
4388 "op": Op.RESIZE,
4389 "operands": (1, 0),
4390 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004391 "build_fcn": (
4392 build_resize,
4393 TosaTensorGen.tgNHWC,
4394 TosaTensorValuesGen.tvgDefault,
4395 TosaArgGen.agResize,
4396 ),
James Ward24dbc422022-10-19 12:20:31 +01004397 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004398 "invalid_test_validators": (
4399 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004400 ),
4401 "error_if_validators": (
4402 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004403 TosaErrorValidator.evScaleSmallerEqualZero,
4404 TosaErrorValidator.evScaleNLargerMax,
4405 TosaErrorValidator.evScaleDLargerMax,
4406 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004408 TosaErrorValidator.evBorderSmallerMin,
4409 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004410 TosaErrorValidator.evWrongInputType,
4411 TosaErrorValidator.evWrongOutputType,
4412 TosaErrorValidator.evWrongRank,
4413 TosaErrorValidator.evWrongInputList,
4414 TosaErrorValidator.evWrongOutputList,
4415 TosaErrorValidator.evBatchMismatch,
4416 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004417 TosaErrorValidator.evResizeOutputShapeMismatch,
4418 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004419 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004420 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004421 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004422 "cast": {
4423 "op": Op.CAST,
4424 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004425 "build_fcn": (
4426 build_cast,
4427 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004428 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004429 TosaArgGen.agCast,
4430 ),
James Ward8b390432022-08-12 20:48:56 +01004431 "types": (
4432 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004433 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004434 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004435 DType.INT8,
4436 DType.INT16,
4437 DType.INT32,
4438 DType.BOOL,
4439 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 "error_if_validators": (
4441 TosaErrorValidator.evWrongInputType,
4442 TosaErrorValidator.evWrongOutputType,
4443 TosaErrorValidator.evWrongInputList,
4444 TosaErrorValidator.evWrongOutputList,
4445 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004446 "data_gen": {
4447 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4448 },
4449 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004450 },
4451 "rescale": {
4452 "op": Op.RESCALE,
4453 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004454 "build_fcn": (
4455 build_rescale,
4456 TosaTensorGen.tgBasic,
4457 TosaTensorValuesGen.tvgDefault,
4458 TosaArgGen.agRescale,
4459 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004460 "types": [
4461 DType.UINT8,
4462 DType.INT8,
4463 DType.INT16,
4464 DType.INT32,
4465 DType.INT48,
4466 DType.UINT16,
4467 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004468 "error_if_validators": (
4469 TosaErrorValidator.evInputZeroPointNotZero,
4470 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004471 TosaErrorValidator.evU16InputZeroPointNotValid,
4472 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004473 TosaErrorValidator.evScaleTrue,
4474 TosaErrorValidator.evScaleNotTrue,
4475 TosaErrorValidator.evWrongInputType,
4476 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004477 TosaErrorValidator.evWrongInputList,
4478 TosaErrorValidator.evWrongOutputList,
4479 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004480 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004481 # Custom
4482 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004484 # Two varients of cond_if, one that generates one of two constant tensors (no
4485 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4486 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 "cond_if_const": {
4488 "op": Op.COND_IF,
4489 "operands": (0, 2),
4490 "build_fcn": (
4491 build_cond_if_const,
4492 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004493 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004494 TosaArgGen.agCondIf,
4495 ),
4496 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004497 "error_if_validators": (
4498 TosaErrorValidator.evOutputListThenGraphMismatch,
4499 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004500 TosaErrorValidator.evCondIfCondNotMatchingBool,
4501 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004502 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004503 },
4504 "cond_if_binary": {
4505 "op": Op.COND_IF,
4506 "operands": (2, 0),
4507 "build_fcn": (
4508 build_cond_if_binary,
4509 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004510 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004511 TosaArgGen.agCondIf,
4512 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004513 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004514 "error_if_validators": (
4515 TosaErrorValidator.evInputListThenGraphMismatch,
4516 TosaErrorValidator.evInputListElseGraphMismatch,
4517 TosaErrorValidator.evOutputListThenGraphMismatch,
4518 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004519 TosaErrorValidator.evCondIfCondNotMatchingBool,
4520 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004521 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004522 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004523 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004524 "while_loop": {
4525 "op": Op.WHILE_LOOP,
4526 "operands": (0, 1),
4527 "build_fcn": (
4528 build_while_loop,
4529 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004530 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004531 TosaArgGen.agWhileLoop,
4532 ),
4533 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004534 "error_if_validators": (
4535 TosaErrorValidator.evInputListOutputListMismatch,
4536 TosaErrorValidator.evInputListCondGraphMismatch,
4537 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4538 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4539 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004540 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004542 },
Luke Hutton57287132023-02-06 14:54:18 +00004543 "fft2d": {
4544 "op": Op.FFT2D,
4545 "operands": (2, 0),
4546 "rank": (3, 3),
4547 "build_fcn": (
4548 build_fft2d,
4549 TosaTensorGen.tgFFT2d,
4550 TosaTensorValuesGen.tvgDefault,
4551 TosaArgGen.agFFT2d,
4552 ),
4553 "types": [DType.FP32],
4554 "error_if_validators": (
4555 TosaErrorValidator.evWrongInputType,
4556 TosaErrorValidator.evWrongOutputType,
4557 TosaErrorValidator.evWrongInputList,
4558 TosaErrorValidator.evWrongOutputList,
4559 TosaErrorValidator.evWrongRank,
4560 TosaErrorValidator.evBatchMismatch,
4561 TosaErrorValidator.evKernelNotPowerOfTwo,
4562 TosaErrorValidator.evFFTInputShapeMismatch,
4563 TosaErrorValidator.evFFTOutputShapeMismatch,
4564 ),
4565 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004566 "rfft2d": {
4567 "op": Op.RFFT2D,
4568 "operands": (1, 0),
4569 "rank": (3, 3),
4570 "build_fcn": (
4571 build_rfft2d,
4572 TosaTensorGen.tgRFFT2d,
4573 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004574 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004575 ),
4576 "types": [DType.FP32],
4577 "error_if_validators": (
4578 TosaErrorValidator.evWrongInputType,
4579 TosaErrorValidator.evWrongOutputType,
4580 TosaErrorValidator.evWrongInputList,
4581 TosaErrorValidator.evWrongOutputList,
4582 TosaErrorValidator.evWrongRank,
4583 TosaErrorValidator.evBatchMismatch,
4584 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004585 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004586 ),
4587 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004588 }
4589
Kevin Cheng550ccc52021-03-03 11:21:43 -08004590
Eric Kunzee5e26762020-10-13 16:11:07 -07004591class OutputShaper:
4592 # Methods in this class compute the expected output shape and datatype
4593 # for common classes of operations
def __init__(self):
    """Nothing to set up: the constructor stores no instance state."""
4596
4597 # These methods return arguments that can be used for
4598 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting element-wise binary op.

    Output shape: for a normal test (``error_name`` is None), each dimension
    of ``a`` equal to 1 is replaced by the matching dimension of ``b``. In
    every other case, ``a``'s dimension is kept as-is.

    ERROR_IF variants:
      * ``ErrorIf.DimensionMismatch``: one randomly chosen output dimension
        is increased by one.
      * ``ErrorIf.WrongOutputType``: the output dtype is drawn at random from
        a fixed dtype list with ``a.dtype`` removed.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        rng: Random generator providing ``integers`` and ``choice``.
        a: First input tensor (``.shape`` and ``.dtype`` are read).
        b: Second input tensor (``.shape`` and ``.dtype`` are read).
        error_name: The ``ErrorIf`` case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # Equal ranks are only required when a rank mismatch is not the error
    # under test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # b.shape is only indexed where broadcasting applies (normal tests),
    # so error cases never read past the end of a shorter b.shape.
    shape = [
        b.shape[idx] if (dim == 1 and error_name is None) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The index is drawn unconditionally, so rng advances even when no
    # dimension ends up being fuzzed.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        other_dtypes = list(set(dtype_pool) - set([a.dtype]))
        return ser.addOutput(shape, rng.choice(other_dtypes))

    return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004632
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    No broadcasting is applied: ``a`` and ``b`` must have the same rank,
    dtype and dimensions, which is enforced with ``assert``. The output uses
    that common shape and ``a.dtype``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        a: First input tensor (``.shape`` and ``.dtype`` are read).
        b: Second input tensor (``.shape`` and ``.dtype`` are read).

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Ranks are equal at this point, so zip visits every dimension pair.
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    # Pass a fresh list, so the output shape never aliases a.shape.
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004644
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output always has ``a.shape``. Its dtype is ``a.dtype``, except for
    ``ErrorIf.WrongOutputType`` tests. For those, a dtype other than
    ``a.dtype`` is picked at random from a fixed list.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        rng: Random generator providing ``choice``.
        a: Input tensor (``.shape`` and ``.dtype`` are read).
        error_name: The ``ErrorIf`` case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004663
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for a select-style op on ``cond``, ``a``, ``b``.

    Output shape follows ``cond``. For a normal test (``error_name`` is
    None), each dimension where ``cond`` is 1 takes the largest of the
    ``cond``, ``a`` and ``b`` dimensions at that position.

    ERROR_IF variants:
      * ``ErrorIf.DimensionMismatch``: one randomly chosen output dimension
        is increased by one.
      * ``ErrorIf.WrongOutputType``: the output dtype is drawn at random from
        a fixed dtype list with ``a.dtype`` removed.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        rng: Random generator providing ``integers`` and ``choice``.
        cond: Condition tensor (``.shape`` is read).
        a: First value tensor (``.shape`` and ``.dtype`` are read).
        b: Second value tensor (``.shape`` and ``.dtype`` are read).
        error_name: The ``ErrorIf`` case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    # a.shape / b.shape are only indexed where cond broadcasts in a normal
    # test; error cases copy cond's dimension unchanged.
    shape = [
        max(c_dim, a.shape[idx], b.shape[idx])
        if (c_dim == 1 and error_name is None)
        else c_dim
        for idx, c_dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally, so rng advances even when nothing is fuzzed.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        other_dtypes = list(set(dtype_pool) - set([a.dtype]))
        return ser.addOutput(shape, rng.choice(other_dtypes))

    return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004697
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting comparison op.

    Output shape: each dimension of ``a`` equal to 1 is replaced by the
    matching dimension of ``b`` when ``b`` has that dimension. Otherwise,
    ``a``'s dimension is kept. The output dtype is ``DType.BOOL``.

    ERROR_IF variants:
      * ``ErrorIf.DimensionMismatch``: one randomly chosen output dimension
        is increased by one.
      * ``ErrorIf.WrongOutputType``: the output dtype is drawn at random from
        a fixed list that does not contain BOOL.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        rng: Random generator providing ``integers`` and ``choice``.
        a: First input tensor (``.shape`` and ``.dtype`` are read).
        b: Second input tensor (``.shape`` and ``.dtype`` are read).
        error_name: The ``ErrorIf`` case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast a against b. The idx < b_rank guard stops rank-mismatched
    # inputs from indexing past the end of b.shape.
    b_rank = len(b.shape)
    shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally, so rng advances even when nothing is fuzzed.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.BOOL)

    # BOOL (the valid output dtype) is not in this list, so any pick is wrong.
    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004731
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction along ``axis``.

    Normally the output is a copy of ``a.shape`` with dimension ``axis`` set
    to 1. For the ``AxisSmallerZero``, ``AxisLargerRank`` and
    ``ShapeOfAxisNotOne`` error cases, the input shape is kept unchanged.
    For ``ShapeOfAxisNotOne``, a dimension at ``axis`` that is already 1 is
    replaced with a random value in [2, 10), so the reduced axis is never 1.

    For ``ErrorIf.WrongOutputType``, the output dtype is drawn at random
    from a fixed dtype list with ``a.dtype`` removed; otherwise it is
    ``a.dtype``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` registers the output.
        rng: Random generator providing ``integers`` and ``choice``.
        a: Input tensor (``.shape`` must support ``copy()``; ``.dtype`` read).
        axis: Axis being reduced.
        error_name: The ``ErrorIf`` case being generated, or None.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    shape = a.shape.copy()

    keeps_axis = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if not keeps_axis:
        shape[axis] = 1
    # Short-circuit order matters: shape[axis] is only read for the
    # ShapeOfAxisNotOne case, where axis is in range.
    if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
        shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004760
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for ARGMAX.

    The reduced ``axis`` is removed from the shape and the dtype is INT32.
    ERROR_IF variants:
      * AxisSmallerZero / AxisLargerRank keep the input shape as-is;
      * ArgmaxOutputRankMismatch drops the leading dim or appends a 1;
      * ArgmaxOutputShapeMismatch grows every dim by a random [1, 10);
      * WrongOutputType picks a random non-INT32 dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for pos, dim in enumerate(out_shape):
            out_shape[pos] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {DType.INT32}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07004794
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial output
    size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``, and the output channel count is ``filter.shape[0]``.

    The output dtype is ``accum_dtype``, or INT32 for WrongInputType tests.
    ConvOutputShapeMismatch grows H and/or W by whole strides, and
    WrongOutputType picks a dtype from ``gtu.usableDTypes`` that excludes
    the valid output type(s).

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        dilated_span = (kernel - 1) * dilation
        return (in_size - 1 + pad_before + pad_after - dilated_span) // stride + 1

    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit
        if which in (1, 3):
            h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a bad input type, any plausible output type will do
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input may legitimately produce FP16 or FP32, so exclude both
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004846
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a 3D convolution.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial
    output size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``, and the output channel count is ``filter.shape[0]``.

    The output dtype is ``accum_dtype``, or INT32 for WrongInputType tests.
    ConvOutputShapeMismatch grows D, H and/or W by whole strides, and
    WrongOutputType picks a dtype from ``gtu.usableDTypes`` that excludes
    the valid output type(s).

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        dilated_span = (kernel - 1) * dilation
        return (in_size - 1 + pad_before + pad_after - dilated_span) // stride + 1

    d = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = _out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit
        if which in (1, 4):
            d += rng.choice(steps) * strides[0]
        if which in (2, 4):
            h += rng.choice(steps) * strides[1]
        if which in (3, 4):
            w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # With a bad input type, any plausible output type will do
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input may legitimately produce FP16 or FP32, so exclude both
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4908
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a depthwise 2D convolution.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial
    output size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``, and the output channel count is
    ``filter.shape[2] * filter.shape[3]``.

    The output dtype is ``accum_dtype``, or INT32 for WrongInputType tests.
    ConvOutputShapeMismatch grows H and/or W by whole strides, and
    WrongOutputType picks a dtype from ``gtu.usableDTypes`` that excludes
    the valid output type(s).

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        dilated_span = (kernel - 1) * dilation
        return (in_size - 1 + pad_before + pad_after - dilated_span) // stride + 1

    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit
        if which in (1, 3):
            h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # With a bad input type, any plausible output type will do
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input may legitimately produce FP16 or FP32, so exclude both
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004959
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op (NHWC in and out).

    Spatial sizes are ``(in + pad_before + pad_after - kernel) // stride + 1``.
    A non-positive stride or a negative pad can only occur in an invalid
    (ERROR_IF) test, so the spatial size is then simply 1x1.
    PoolingOutputShapeMismatch grows H and/or W by whole strides, and
    WrongOutputType picks a random dtype different from ``ifm.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit
        if which in (1, 3):
            h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            w += rng.choice(steps) * stride[1]
    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {ifm.dtype}
    return ser.addOutput(ofm_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07004997
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for a fully-connected op.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).
    The output dtype is ``accum_dtype`` as given. ``rng`` and ``error_name``
    are unused here because the dtype was already validated (or
    deliberately invalidated for ERROR_IF tests) in arg_gen.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005010
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    The output dtype is ``accum_dtype`` (validated in arg_gen). ERROR_IF
    variants:
      * WrongOutputType picks a random dtype that is invalid for the input
        dtype. Types without a dedicated list previously left
        ``incorrect_types`` unbound (UnboundLocalError); they now fall back
        to every usable dtype except ``accum_dtype``, matching the conv
        output shapers.
      * WrongInputType uses INT32 as a plausible output dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input type: everything except the expected output
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005058
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT.

    The output shape is the first input's shape with the ``axis`` dimension
    replaced by the sum over all inputs. The sum is skipped for ERROR_IF
    tests where it cannot be computed (rank mismatch or an invalid axis).
    ConcatShapeSumMismatch adds a random [5, 10) to the axis size, and
    WrongOutputType picks a random dtype different from the first input's.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    axis_sum_possible = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if axis_sum_possible:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, first.dtype)

    candidates = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    } - {first.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005094
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each output dimension is the input dimension plus ``padding[i][0]`` and
    ``padding[i][1]``. ERROR_IF variants:
      * PadOutputShapeMismatch shrinks one random dimension by 1 or 2;
      * RankMismatch replaces the shape via ``gtu.get_rank_mismatch_shape``;
      * WrongOutputType picks a random dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    for dim_idx in range(len(out_shape)):
        pad_before, pad_after = padding[dim_idx][0], padding[dim_idx][1]
        out_shape[dim_idx] += pad_before + pad_after

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {a.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005125
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM.

    The output is created with an empty shape list and dtype SHAPE. A
    WrongOutputType ERROR_IF test picks a random non-SHAPE dtype instead.
    ``a`` and ``axis`` are not read here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    non_shape_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    # Every entry is already wrong; the set() round-trip only fixes the
    # ordering rng.choice sees, and is kept so a given seed picks the same
    # dtype as before.
    return ser.addOutput([], rng.choice(list(set(non_shape_dtypes))))
5146
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE.

    The output uses a copy of the requested ``shape`` and ``a.dtype``.
    TensorSizeInputOutputMismatch grows every dimension by a random [1, 10),
    and WrongOutputType picks a random dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for pos, dim in enumerate(out_shape):
            out_shape[pos] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {a.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005171
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE.

    The output shape is a copy of ``size`` and the dtype follows the input.
    ERROR_IF variants:
      * WrongOutputType picks a random dtype different from ``input.dtype``;
      * SizeOutputShapeMismatch nudges every dim (by +1/+2 when the dim is
        <= 2, otherwise by one of -2/-1/+1/+2);
      * InputSizeStartLengthMismatch uses the input's shape instead;
      * RankMismatch replaces the shape via ``gtu.get_rank_mismatch_shape``.
    ``start`` is not read here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        ) - {input.dtype}
        out_dtype = rng.choice(list(candidates))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for pos, dim in enumerate(out_shape):
            # Small dims only grow so they never become non-positive
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[pos] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005205
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Each output dimension is ``a.shape[i] * multiples[i]``. RankMismatch
    replaces the shape via ``gtu.get_rank_mismatch_shape``, and
    WrongOutputType picks a random dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for pos, dim in enumerate(a.shape):
        out_shape[pos] = dim * multiples[pos]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {a.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005234
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE.

    Output dimension ``i`` is ``a.shape[perms[i]]``. The permutation is not
    applied for IndexOutsideBounds / IndexUsedTwice tests, because ``perms``
    is invalid there. TensorSizeInputOutputMismatch grows every dim by a
    random [1, 10). RankMismatch replaces the shape via
    ``gtu.get_rank_mismatch_shape``. WrongOutputType picks a random dtype
    different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for pos in range(len(out_shape)):
            out_shape[pos] = a.shape[perms[pos]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for pos in range(len(out_shape)):
            out_shape[pos] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {a.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    ``values`` is rank 3 (unless this is a WrongRank test) and ``indices``
    is rank 2 with the same leading dimension. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``, and the dtype
    follows ``values`` unless WrongOutputType picks a different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    gathered = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, gathered, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {values.dtype}
    return ser.addOutput(out_shape, rng.choice(list(candidates)))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005293
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    ``values_in`` and ``input`` are rank 3 and ``indices`` is rank 2 (unless
    this is a WrongRank test). They must agree on N (values_in/indices),
    W (input/indices) and C (values_in/input). The output has the shape of
    ``values_in`` and its dtype, unless WrongOutputType picks a different
    dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    candidates = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    ) - {values_in.dtype}
    return ser.addOutput(values_in.shape, rng.choice(list(candidates)))
Eric Kunzee5e26762020-10-13 16:11:07 -07005322
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output has the input's shape. Its dtype is INT32 for INT16 input and
    INT8 otherwise. The input must be INT8/INT16 unless this is a
    WrongInputType test. WrongOutputType picks a random dtype other than the
    expected one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT8
    if input.dtype == DType.INT16:
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != out_dtype
        ]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005342
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE (NHWC).

    ``scale`` is ``[y_num, y_den, x_num, x_den]``. The output size per axis is
    ``((in - 1) * num - offset + border) // den + 1``.

    For any ERROR_IF test the size is clamped to at least 1, and (except for
    MaxDimExceeded) to below ``gtu.MAX_RESIZE_DIMENSION``. Further variants:
      * ScaleSmallerEqualZero clamps the scale terms to >= 1 first;
      * ResizeOutputShapeMismatch moves OH and/or OW by one denominator;
      * WrongRank / BatchMismatch / ChannelMismatch alter the output dims.
    ``mode`` and ``input_dtype`` are not read here.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        y_num, y_den, x_num, x_den = (
            max(term, 1) for term in (y_num, y_den, x_num, x_den)
        )

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Keep the output tensor constructible even though scale, offset or
        # border may have been corrupted for the ERROR_IF test
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    def _nudge(size, step):
        # Move by a whole denominator so the non-integer error is not hit,
        # stepping down instead of up when up would exceed the max dimension
        if size + step >= gtu.MAX_RESIZE_DIMENSION:
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size
        return size + step

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        if which in (1, 3):
            oh = _nudge(oh, y_den)
        if which in (2, 3):
            ow = _nudge(ow, x_den)

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim is input.shape[0], not input.shape[3];
        # possibly deliberate because the input rank is wrong here - confirm.
        output_dims = [input.shape[0], oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [input.shape[0] + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [input.shape[0], oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005421
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register the output tensor of a type-conversion operator.

    The new tensor copies the shape of ``val`` and uses ``out_dtype`` as its
    element type. ``rng`` and ``error_name`` are unused here; they are only
    accepted so this helper has the same call shape as the other
    output-shaper helpers.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005425
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor of a transposed 2D convolution.

    Args:
        ser: serializer; ``ser.addOutput`` creates the tensor.
        rng: random generator, consulted only when an ERROR_IF case asks
            for a corrupted shape or dtype.
        ifm: input feature-map tensor; only ``ifm.dtype`` is read.
        output_shape: requested output shape. NOTE: for
            ``ErrorIf.ConvOutputShapeMismatch`` this object is modified in
            place (index 1 and/or 2 is enlarged), so the caller sees the
            corrupted shape too.
        accum_dtype: accumulator dtype, used as the output dtype unless an
            ERROR_IF case overrides it.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        # 1 -> enlarge index 1 only, 2 -> index 2 only, 3 -> both
        # (presumably the spatial H/W dims - confirm against the caller).
        which = rng.choice(deltas)
        for dim, triggers in ((1, (1, 3)), (2, (2, 3))):
            if which in triggers:
                output_shape[dim] = output_shape[dim] + rng.choice(deltas)

    if error_name == ErrorIf.WrongOutputType:
        # Leave out the dtype(s) this generator treats as acceptable outputs,
        # so the random pick is a genuinely wrong one.
        if ifm.dtype == DType.FP16:
            acceptable = [DType.FP16, DType.FP32]
        else:
            acceptable = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=acceptable)))
    elif error_name == ErrorIf.WrongInputType:
        # The input type is the deliberate error; use a plausible output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005451
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors of FFT2D.

    Both outputs get the same shape and dtype, presumably one each for the
    real and imaginary parts. By default these are ``ifm1``'s shape and
    dtype. An ERROR_IF case can corrupt them instead:

    * ``WrongOutputType``: dtype becomes a random usable dtype other than FP32.
    * ``BatchMismatch``: index 0 of the shape grows by 1..9.
    * ``FFTOutputShapeMismatch``: index 1 or 2 of the shape grows by 1..9.

    Args:
        serializer: serializer; ``serializer.addOutput`` creates each tensor.
        rng: random generator used only for the corruptions above.
        ifm1, ifm2: the two input tensors; they must share a dtype.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        A list of the two tensors created by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    input_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    # list() copies any sequence type (list, tuple, ...), unlike .copy(),
    # so the tweaks below never alias ifm1.shape. rfft2dOp also builds a
    # fresh list from its input shape.
    output_shape = list(input_shape)
    output_dtype = input_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP32]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    # Two outputs with identical shape and dtype.
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5482
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors of RFFT2D.

    The output shape equals the input shape, except that the last axis
    shrinks to ``last // 2 + 1``. This is presumably the non-redundant half
    of a real-input spectrum. Both outputs take ``value``'s dtype unless an
    ERROR_IF case corrupts the dtype or shape:

    * ``WrongOutputType``: dtype becomes a random usable dtype other than FP32.
    * ``BatchMismatch``: index 0 of the shape grows by 1..9.
    * ``FFTOutputShapeMismatch``: index 1 or 2 of the shape grows by 1..9.

    Args:
        serializer: serializer; ``serializer.addOutput`` creates each tensor.
        rng: random generator used only for the corruptions above.
        value: the input tensor.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        A list of the two tensors created by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    first = serializer.addOutput(out_shape, out_dtype)
    second = serializer.addOutput(out_shape, out_dtype)
    return [first, second]