blob: b1c53f5aa4f667058881c588d7712d9ea1f8c143 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner placed at the top of every generated C++ meta data file; the
# copyright year is evaluated once, when this module is imported.
TOSA_AUTOGENERATED_HEADER = "\n".join(
    (
        f"// Copyright (c) {datetime.today().year}, ARM Limited",
        "// SPDX-License-Identifier: Apache-2.0",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests",
        "",
    )
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further limits named after the 8K level.
# NOTE(review): presumably the maximum resize scale, kernel size and stride
# permitted at that level -- confirm against the TOSA specification levels.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
# NOTE(review): presumably the number of dot product test data sets and the
# minimum dot product count used by the MI compliance tests -- confirm with
# the code that reads these constants.
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the test generator from the parsed command-line ``args``.

    Seeds ``self.rng`` from ``args.random_seed``, builds the op lists via
    createDynamicOpLists/initOpListDefaults, and prepares the quantization
    helper, the desc.json schema validator, the data generator library and
    the random value range used for each floating point dtype.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to return a specific shape (None means random shapes)
    self.targetted_shape = None
    # Validates every desc.json before it is written
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def limitToType(bound, maxFP):
        # "max"/"-max" select the dtype extremes; any other value is trimmed
        # into [-maxFP, maxFP] (zero is passed through unchanged)
        if bound == "max":
            return maxFP
        if bound == "-max":
            return -maxFP
        if bound < 0:
            return max(bound, -maxFP)
        if bound > 0:
            return min(bound, maxFP)
        return bound

    def floatRangeFor(requested, maxFP):
        # Ascending order, so a two-value request becomes (low, high)
        return tuple(sorted(limitToType(bound, maxFP) for bound in requested))

    self.random_float_range = {
        fpType: floatRangeFor(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fpType],
        )
        for fpType in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create a fresh TOSA serializer for test ``testPath`` of op ``opName``.

    Sets ``self.testPath`` to ``<opName>/<testPath>`` and creates the
    directory ``<basePath>/<opName>/<testPath>`` if it does not exist.
    Constants use ConstMode.INPUTS with lazy data generation, EMBED_DUMP when
    dump_consts is requested, and are embedded (EMBED) otherwise.
    """
    self.testPath = os.path.join(opName, testPath)

    testDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(testDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(testDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the serializer from the latest createSerializer call (initially None)."""
    currentSerializer = self.ser
    return currentSerializer
101
def serialize(self, testName, metaData=None):
    """Write the current test's output files into ``<basePath>/<testPath>``.

    Always writes ``<testName>.tosa`` (the serialized flatbuffer) and
    ``desc.json`` (the serializer's JSON descriptor). When ``metaData`` is
    non-empty it is stored in desc.json under "meta", and additionally:

    * ``metaData["data_gen"]`` is written to ``<testName>_meta_data_gen.cpp``
      (only when ``args.lazy_data_gen`` is set);
    * ``metaData["compliance"]`` is written to
      ``<testName>_meta_compliance.cpp``.

    Each .cpp file embeds the JSON as a C++ raw string literal. The
    descriptor is validated with ``descSchemaValidator`` after the .tosa file
    is written but before any .cpp file or desc.json is written.

    Args:
        testName: base name for the .tosa and .cpp files.
        metaData: optional dict of extra test meta data.
    """
    path = Path(self.basePath) / self.testPath

    def writeMetaCpp(fileSuffix, description, varPrefix, data):
        # Emit one meta data dict as a C++ raw string named
        # <varPrefix>_<path.stem>
        path_md = path / f"{testName}_{fileSuffix}.cpp"
        with path_md.open("w", encoding="utf-8") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// {description}\n\n")
            fd.write(f'const char* {varPrefix}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation meta data is only output for lazy data generation
            writeMetaCpp(
                "meta_data_gen",
                "Test meta data for data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance validation meta data as CPP data
            writeMetaCpp(
                "meta_compliance",
                "Test meta data for compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w", encoding="utf-8") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a new generator.

    Uses ``seed`` when given, otherwise ``self.random_seed + 1``.
    """
    chosenSeed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosenSeed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) bounds used for random values of ``dtype``.

    FP32/FP16/BF16 return the pre-computed float range unchanged. For every
    other supported type ``high`` is exclusive (low <= v < high) unless
    ``high_inclusive`` is True, in which case ``high - 1`` is returned.

    Raises:
        Exception: if ``dtype`` has no known range.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    exclusiveRanges = {
        DType.BOOL: (0, 2),
        DType.UINT8: (0, 256),
        DType.UINT16: (0, 65536),
        # TOSA specific INT4 weight range from -7 to 7
        DType.INT4: (-7, 8),
        DType.INT8: (-128, 128),
        DType.INT16: (-32768, 32768),
        DType.INT32: (-(1 << 31), 1 << 31),
        # SHAPE shares the INT32 range, restricting overly large values
        DType.SHAPE: (-(1 << 31), 1 << 31),
        DType.INT48: (-(1 << 47), 1 << 47),
    }
    if dtype not in exclusiveRanges:
        raise Exception("Unknown dtype: {}".format(dtype))

    low, high = exclusiveRanges[dtype]
    return (low, high - 1) if high_inclusive else (low, high)
184
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of ``shape`` for TOSA type ``dtype``.

    Values come from ``data_range`` (low, high) when given, otherwise from
    getDTypeRange(dtype). The result types are:
    * BOOL: random False/True values (the range is ignored);
    * INT48: int64;
    * FP16: float16;
    * FP32: float32;
    * BF16: float32 passed through gtu.vect_f32_to_bf16;
    * every other type: int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)
    asF32 = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(asF32))
    return asF32
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless ``args.lazy_data_gen``
    is set, in which case None is passed as the data.

    Returns:
        The values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addPlaceholder(
            shape,
            dtype,
            None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype),
        )
        for shape, dtype in zip(shape_list, dtype_list)
    ]
223
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless ``args.lazy_data_gen``
    is set, in which case None is passed as the data.

    Returns:
        The values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addConst(
            shape,
            dtype,
            None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype),
        )
        for shape, dtype in zip(shape_list, dtype_list)
    ]
236
def makeShape(self, rank):
    """Return an int32 numpy shape with ``rank`` dimensions.

    If a target shape has been set (setTargetShape), that shape is returned
    instead, whatever ``rank`` is. Otherwise each dimension is drawn with
    ``self.rng`` from [tensor_shape_range[0], tensor_shape_range[1]).
    """
    # Compare with None instead of relying on truthiness: a numpy array
    # target would raise "truth value ... is ambiguous", and an empty
    # (rank 0) target would otherwise be silently ignored.
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700247
def setTargetShape(self, shape):
    """Make subsequent makeShape calls return ``shape``; pass None to go back to random shapes."""
    self.targetted_shape = shape
250
def randInt(self, low=0, high=256):
    """Return one np.int32 drawn with ``self.rng`` from [low, high)."""
    samples = self.rng.integers(low=low, high=high, size=1)
    return np.int32(samples)[0]
253
def getRandNumberDType(self, dtype):
    """Return a single random scalar of TOSA type ``dtype``.

    The value is drawn from getDTypeRange(dtype). The result types are:
    * FP32: np.float32;
    * FP16: np.float16;
    * BF16: gtu.vect_f32_to_bf16 applied to a float32 draw;
    * BOOL: a random choice of False/True;
    * INT48: np.int64;
    * any other type: np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP32 or dtype == DType.FP16:
        floatCast = np.float32 if dtype == DType.FP32 else np.float16
        return floatCast(self.rng.uniform(low=low, high=high))
    if dtype == DType.BF16:
        asF32 = np.float32(self.rng.uniform(low=low, high=high))
        return gtu.vect_f32_to_bf16(asF32)
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 needs the wider (special size) container
    intCast = np.int64 if dtype == DType.INT48 else np.int32
    return intCast(self.rng.integers(low, high, size=1))[0]
271
def shapeStr(self, shape):
    """Return ``shape`` as an "x"-separated string, e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700280
def typeStr(self, dtype):
    """Return the short name string for ``dtype`` used in test names.

    A list/tuple of at least two dtypes is rendered as the strings of its
    first two entries joined by "x"; any third entry (the accumulator type)
    is converted (and so checked) but left out of the name.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
def typeWidth(self, dtype):
    """Return the "width" recorded for ``dtype`` in gtu.DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700301
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` (the batch size) to ``args.max_batch_size`` in place.

    The clamp is skipped when max_batch_size is falsy or when explicit target
    shapes were given. Returns the same (possibly modified) ``shape`` object.
    """
    limitApplies = self.args.max_batch_size and not self.args.target_shapes
    if limitApplies:
        shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
307
def makeDimension(self):
    """Return one random dimension size via randInt over args.tensor_shape_range[0:2]."""
    lowest = self.args.tensor_shape_range[0]
    beyondHighest = self.args.tensor_shape_range[1]
    return self.randInt(low=lowest, high=beyondHighest)
312
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Return the compliance meta data dict for ``outputTensor``, or None.

    None is returned in three cases:
    * ERROR_IF tests (``errorName`` set);
    * output dtypes not supported by compliance;
    * MATMUL/CONV2D/FULLY_CONNECTED with an input type not supported by
      compliance.

    Otherwise the dict holds "mode", "data_type" and any mode-specific info.
    The mode is chosen in this priority order:
      DOT_PRODUCT    - argsDict["dg_type"] is DataGenType.DOT_PRODUCT
      FP_SPECIAL     - argsDict["dg_type"] is DataGenType.OP_SPECIAL
      ULP            - op["compliance"] has a "ulp" entry
      REDUCE_PRODUCT - the op is REDUCE_PRODUCT
      ABS_ERROR      - the op is EXP, POW, TANH or SIGMOID
      EXACT          - anything else
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    inputUnsupported = not gtu.dtypeIsSupportedByCompliance(inputType)
    if inputUnsupported and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS:
        return None

    # Create compliance meta data for expected output tensor
    info = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    dataGenType = argsDict["dg_type"]

    if dataGenType == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        dotInfo = {"s": argsDict["s"]}
        # "ksb" takes precedence over "ks" when present
        dotInfo["ks"] = int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"])
        info["dot_product_info"] = dotInfo
    elif dataGenType == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        info["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        info["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
            info["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    info["mode"] = gtu.ComplianceMode(mode).name
    return info
362
# Build Op functions
# Each builder creates the output tensor (calling OutputShaper as needed),
# makes any final attribute tweaks needed for ERROR_IF tests, adds the op
# into the graph and returns the resulting tensor information or a BuildInfo.

class BuildInfo:
    """Pairs a builder's result tensor with its compliance meta data (may be None)."""

    def __init__(self, resultTensor, complianceDict):
        # Both values are stored unchanged for the caller to read back
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
374 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700375
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input op to the graph and return its BuildInfo.

    The output tensor comes from OutputShaper.unaryOp. Statement order
    matters: OutputShaper and the ERROR_IF helpers receive self.rng / self,
    so reordering would change the generated tests.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.
        qinfo: input/output zero points; for NEGATE they are passed to
            NegateAttribute, so presumably must not be None then (TODO confirm).

    Returns:
        TosaTestGen.BuildInfo, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    # Only NEGATE carries an attribute (its zero points)
    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700428
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two-input op to the graph and return its BuildInfo.

    The output tensor comes from OutputShaper.binaryBroadcastOp. Statement
    order matters because OutputShaper and the ERROR_IF helpers receive
    self.rng / self.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments, passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions (required here).
        error_name: the ErrorIf under test, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700470
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input op shaped by OutputShaper.binaryNonBroadcastOp.

    ``validator_fcns`` and ``error_name`` are accepted but unused: no
    ERROR_IF validation happens here. Returns the result tensor.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    inputNames = [a.name, b.name]
    self.ser.addOperator(op["op"], inputNames, [output.name])
    return output
475
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an ARITHMETIC_RIGHT_SHIFT-style op on ``a`` and ``b``.

    The output tensor comes from OutputShaper.binaryBroadcastOp; statement
    order matters because OutputShaper and the ERROR_IF helpers receive
    self.rng / self.

    Args:
        op: op definition dict (reads "op" and "operands").
        a, b: the two input tensors.
        round: value passed to ArithmeticRightShiftAttribute. NOTE: this
            parameter name shadows the builtin ``round`` inside this method.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
513
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL-style op to the graph and return its BuildInfo.

    The output shape comes from OutputShaper.binaryBroadcastOp. For
    non-float inputs the output dtype is forced to INT32. For the
    WrongOutputType ERROR_IF test it is replaced by a random pick from
    INT8/INT16/INT48. The order of the self.rng draws must not change.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments; "shift" is passed to MulAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        outputDType = self.rng.choice(all_dtypes)
        result_tensor.setDtype(outputDType)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700569
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE-style op on ``a`` using ``table`` as its TableAttribute.

    The output tensor comes from OutputShaper.tableOp; statement order
    matters because OutputShaper and the ERROR_IF helpers receive
    self.rng / self.

    Args:
        op: op definition dict (reads "op" and "operands").
        a: the input tensor.
        table: table values passed to TableAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return result_tens
603
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SELECT-style op (condition plus two value inputs) and return its BuildInfo.

    The output tensor comes from OutputShaper.selectOp. Statement order
    matters because OutputShaper and the ERROR_IF helpers receive
    self.rng / self.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list of exactly three tensors: cond, a, b.
        args_dict: test arguments, passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo (compliance based on a.dtype), or None when
        evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    result_tensor = OutputShaper.selectOp(
        self.ser, self.rng, cond, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [cond.name, a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700651
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input comparison op and return its BuildInfo.

    The output tensor comes from OutputShaper.binaryComparisonOp. Statement
    order matters because OutputShaper and the ERROR_IF helpers receive
    self.rng / self.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments, passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700699
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an argmax operator test along ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` now default to None, matching the
    other builders; existing callers that pass them are unaffected.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    # Use the already-unpacked input consistently (same tensor as inputs[0]).
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: Exactly one input tensor.
        args_dict: Test arguments with ``"stride"``, ``"pad"`` and
            ``"kernel"``, plus ``"acc_type"`` when the op has an accumulator
            type (``DType.UNKNOWN`` is used when it is absent).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, output_zp]``; ``[0, 0]`` when None.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100817
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D convolution operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Test arguments with ``"acc_type"``, ``"stride"``,
            ``"pad"`` (four entries) and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, weight_zp]`` used for the attribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
            ]

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, weights.name, bias.name],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700899
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 3D convolution operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Test arguments with ``"acc_type"``, ``"stride"``,
            ``"pad"`` (six entries) and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, weight_zp]`` used for the attribute.

    Returns:
        The result tensor itself (no compliance ``BuildInfo``), or None when
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, weights.name, bias.name],
        [result_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
976
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a transpose 2D convolution operator test (positional-argument form).

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds the operand counts.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride values for the attribute.
        out_pad: Output padding; must have four entries.
        output_shape: Requested output shape, passed to the output shaper and
            serialized into the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, weight_zp]`` used for the attribute.

    Returns:
        The result tensor itself (no compliance ``BuildInfo``), or None when
        ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    needs_requant_info = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if needs_requant_info:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [result_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1044
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a depthwise 2D convolution operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Test arguments with ``"acc_type"``, ``"stride"``,
            ``"pad"`` and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, weight_zp]`` used for the attribute.

    Returns:
        The result tensor itself (no compliance ``BuildInfo``), or None when
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, weights.name, bias.name],
        [result_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1120
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a fully-connected operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: ``[ifm, weights, bias]`` tensors.
        args_dict: Test arguments; must contain ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[input_zp, weight_zp]`` used for the attribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, weights.name, bias.name],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001176
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a matrix-multiply operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: Exactly two input tensors ``[a, b]``.
        args_dict: Test arguments; must contain ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Zero points ``[a_zp, b_zp]`` used for the attribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, lhs, rhs, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        input2_shape=rhs.shape,
        input2_dtype=rhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001226
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reduction operator test along ``args_dict["axis"]``.

    ``validator_fcns`` now defaults to None, matching the other builders;
    existing callers that pass it are unaffected.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; must contain ``"axis"``. For a valid
            REDUCE_PRODUCT test, ``"n"`` is written into this dict.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance; args_dict is handed to
        # tensorComplianceMetaData just below.
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001275
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a clamp operator test with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` case, the values are redrawn until they differ
    and are then deliberately swapped.

    Args:
        op: Operator definition dict; ``op["op"]`` is the operator passed to
            the serializer and ``op["operands"]`` holds two operand counts.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None for a valid test.
        qinfo: Unused by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    is_float_type = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float_type:
        # Integer bounds go in the integer slots; float slots are zeroed
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001341
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a leaky-ReLU operator whose attribute is a random FP32 value.

    ``validator_fcns`` is accepted but unused; no ERROR_IF validation is run.

    Returns:
        The result tensor produced by ``OutputShaper.unaryOp``.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr_value = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
    return result_tens
1350
1351 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001352 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1353 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001354
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001355 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001356 return result_tens
1357
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a single-input, attribute-free activation operator.

    Args:
        op: Operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: Sequence containing exactly one input tensor.
        args_dict: Test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators handed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value (nothing is serialized).
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out,
        self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001398
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CONCAT operator joining every tensor in *inputs*.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Non-empty sequence of input tensors; ``inputs[0]`` supplies
            the shape/dtype used for the ERROR_IF checks and compliance data.
        args_dict: Must provide ``"axis"``; also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Exact-type check: bool and numpy integer axes are not accepted.
        assert type(axis) is int

    out = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    names_in = [tensor.name for tensor in inputs]
    n_placeholders, n_consts = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, names_in, [out.name]
    )

    first = inputs[0]
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out.shape,
        input_dtype=first.dtype,
        output_dtype=out.dtype,
        inputs=inputs,
        result_tensors=[out],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001451
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator for a single input tensor.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"pad"`` (an object with ``flatten()``),
            ``"pad_const_int"`` and ``"pad_const_fp"``; also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Passed through to the ERROR_IF validation.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_values = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, pad_values, error_name)

    # The attribute is handed self.ser.builder, so it is created before the
    # ERROR_IF checks run; keep this ordering.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_values.flatten(), const_int, const_fp)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=pad_values,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001509
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator for one input tensor and an ``axis`` attribute.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"axis"``.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``. No compliance
        metadata is attached. Returns None when evValidateErrorIfs returns
        a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    dim_checks = {
        "op": op,
        "axis": axis,
        "input_shape": src.shape,
        "input_dtype": src.dtype,
        "output_shape": out.shape,
        "output_dtype": out.dtype,
        "result_tensors": [out],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **dim_checks
    ):
        return None

    dim_attr = ts.TosaSerializerAttribute()
    dim_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, dim_attr)
    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001555
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a RESHAPE operator targeting ``args_dict["new_shape"]``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"new_shape"``; also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(
        self.ser, self.rng, src, target_shape, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(target_shape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001601
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REVERSE operator for one input tensor and an ``axis`` attribute.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"axis"``.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``. No compliance
        metadata is attached. Returns None when evValidateErrorIfs returns
        a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    reverse_checks = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **reverse_checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001641
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE operator with permutation *perms*.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        a: The input tensor.
        perms: Permutation passed to transposeOp, the ERROR_IF checks and
            TransposeAttribute.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a falsy
        value.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is created up front, before the ERROR_IF checks.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out
1677
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE operator described by *start* and *size*.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        a: The input tensor.
        start, size: Slice parameters forwarded to sliceOp, the ERROR_IF
            checks and SliceAttribute.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a falsy
        value.
    """
    out = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    slice_checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out.shape,
        "input_dtype": a.dtype,
        "output_dtype": out.dtype,
        "start": start,
        "size": size,
        "result_tensors": [out],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": n_placeholders + n_consts,
        "input1": a,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **slice_checks
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out
1716
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TILE operator with a ``multiples`` attribute.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"multiples"``; also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    multiples = args_dict["multiples"]

    out = OutputShaper.tileOp(self.ser, self.rng, src, multiples, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001764
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence of exactly two tensors: ``(values, indices)``.
        args_dict: Forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value. The ERROR_IF checks and
        compliance data use the *values* tensor's shape/dtype.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out.name]
    )

    gather_checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out.shape,
        input_dtype=values.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **gather_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001807
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SCATTER operator from (values_in, indices, input) tensors.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence of exactly three tensors, in the order
            ``(values_in, indices, input)``.
        args_dict: Forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value. The ERROR_IF checks and
        compliance data use *values_in*'s shape/dtype.
    """
    assert len(inputs) == 3
    # Third tensor is named input_tensor to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [out.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001849
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE operator.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        input: The input tensor. The name is kept for caller compatibility
            even though it shadows the builtin.
        mode, scale, offset, border: Resize parameters forwarded to
            OutputShaper.resizeOp, the ERROR_IF checks and ResizeAttribute.
        input_dtype, output_dtype: Data types forwarded to resizeOp and the
            ERROR_IF checks.
        validator_fcns: ERROR_IF validators (a required argument here).
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a falsy
        value.
    """
    out = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out.name]
    )

    resize_checks = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": out.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [out],
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **resize_checks
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out
1911
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output IDENTITYN operator.

    Args:
        op: Operator description dict, as passed to every other builder in
            this class. The operator code is taken from ``op["op"]``. A bare
            operator code (non-dict) is still accepted for backward
            compatibility and passed through unchanged.
        val, val2: The two input tensors.
        validator_fcns: Accepted for a uniform builder signature; no
            ERROR_IF validation is performed here.
        error_name: Forwarded to OutputShaper.unaryOp for both outputs.

    Returns:
        Only the first result tensor. The second one is registered with the
        serializer as an operator output but is not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The other builders pass op["op"] (the operator code) to addOperator,
    # not the whole description dict; do the same here.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1919
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the constant tensor *val* as a graph output and return it.

    Only ``self.ser.addOutputTensor`` is called. *op*, *validator_fcns* and
    *error_name* are accepted for a uniform builder signature and ignored.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001923
# Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CAST (type conversion) operator for one input tensor.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide ``"out_type"`` (the target dtype passed to
            OutputShaper.typeConversionOp); also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Accepted for a uniform builder signature; unused.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        evValidateErrorIfs returns a falsy value. Compliance data is keyed
        on the *input* dtype.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    cast_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **cast_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001968
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator test.

    Picks input/output zero points suited to the data types (or deliberately
    invalid ones for the zero-point ERROR_IF tests), derives multiplier/shift
    pairs from a random scale and, for positive scale32 tests, clamps the
    stored input data so it meets the apply_scale_32 input requirement.

    Args:
        op: Operator description dict (its "op" and "operands" entries are used).
        val: Input placeholder tensor; for positive scale32 tests its saved
            data file may be rewritten with clamped values.
        out_dtype: Output DType.
        scale32: True for 32-bit multipliers, False for 16-bit ones.
        double_round: Value passed to the double_round attribute.
        per_channel: If True, one multiplier/shift pair per element of the
            last input dimension; otherwise a single pair.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of multiplier/shift pairs
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Choose zero points. Every branch that may use a non-zero zero point also
    # widens the type by one bit for the scale calculation below.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width), with a drawn from [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr[i] is an np.int32 scalar; under NumPy 2 promotion rules
        # (NEP 50) `1 << np.int32(s)` stays int32 and overflows for s > 31,
        # so do the shift arithmetic on a Python int.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2123
def _get_condition_tensor(self, op, cond, error_name):
    """Add the constant condition tensor for a COND_IF test and return it.

    Normally this is a rank-0 BOOL constant holding ``cond``.
    For ErrorIf.CondIfCondNotMatchingBool the type comes from
    gtu.get_wrong_output_type; for ErrorIf.CondIfCondShapeNotSizeOne the
    shape is randomly [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2140
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output an INT32 const.

    Only the shape of ``then_tens`` is used (as the output shape); the
    contents of ``then_tens``/``else_tens`` are ignored and replaced by random
    constants. For the output-list/graph mismatch ERROR_IF tests, the
    offending block outputs a constant with every dimension perturbed.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_mismatch, else_mismatch):
        # Shift each dimension so the block output no longer matches
        bad_shape = deepcopy(then_tens.shape)
        for dim_idx, dim in enumerate(bad_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            bad_shape[dim_idx] = dim + delta
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The op result takes the common output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block is just a const node feeding the block output
    for block_name, block_arr, block_error in (
        (then_block, then_arr, then_mismatch),
        (else_block, else_arr, else_mismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_tens = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2209
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    For FP32/BF16/FP16/INT32 inputs the THEN block uses ADD and the ELSE block
    uses SUB; for INT8/INT16 they use LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT. For the input/output-list mismatch ERROR_IF tests,
    the offending block gets an input or output with perturbed dimensions.

    Args:
        op: Operator description dict for COND_IF.
        a, b: Input tensors shared by the op and both blocks.
        cond: Condition value for the constant condition tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.

    Raises:
        AssertionError: If there are no block ops for ``a.dtype``.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so it is not stripped under `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Use a distinct loop variable: the `op` parameter (the operator dict) is
    # still needed for the ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2292
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that repeatedly adds ``a`` into an accumulator.

    The loop carries (iteration count, a, acc). The COND block computes
    ``iteration count > 0``; the BODY block computes ``acc + a`` and
    ``iteration count - 1``. Depending on ``error_name``, block input/output
    lists get perturbed shapes, or the COND output gets a non-BOOL type or a
    shape that is not size one.

    Args:
        op: Operator description dict for WHILE_LOOP.
        a: Tensor added to the accumulator each iteration.
        iter_val: Initial iteration count.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF under test, or None for a positive test.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation
        rejects the test.
    """
    # Named iter_tens (not `iter`) to avoid shadowing the builtin
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # A rank-0 count cannot be perturbed, so give it a dimension
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2413
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D test from two input tensors.

    ``val1`` and ``val2`` are the operator's two inputs; presumably the real
    and imaginary parts (TODO confirm against OutputShaper.fft2dOp).
    ``inverse`` is passed through to the FFT attribute.

    Returns:
        The list of result tensors, or None if ERROR_IF validation rejects
        the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2464
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D test for a single input tensor.

    Returns:
        The list of result tensors from OutputShaper.rfft2dOp, or None if
        ERROR_IF validation rejects the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2510
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters to generate tests with.

    Requested ranks/dtypes are intersected with what the operator allows.
    With no rank or shape request, ranks are limited to the smaller of the
    operator's range and 1-4. For negative tests, the validator's
    "param_reqs" override each filter that they specify, and the shape list
    is cut to its first two entries otherwise.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or an unknown testType.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small
    default_test_rank_range = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    min_rank, max_rank = op["rank"]
    op_rank_range = range(min_rank, max_rank + 1)

    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        ranks = [rank for rank in rankFilter if min_rank <= rank <= max_rank]
    elif shapes[0] is None:
        # Bounded by the default range or the operator, whichever is smaller
        if len(op_rank_range) <= len(default_test_rank_range):
            ranks = op_rank_range
        else:
            ranks = default_test_rank_range
    else:
        ranks = op_rank_range

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # A list-valued dtype matches on its first entry
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = reqs["rank"]
    dtype_req = reqs["dtype"]
    shape_req = reqs["shape"]

    return {
        # Reduce number of shapes to keep test numbers small
        "shapeFilter": shapes[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2591
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of tests to generate for one operator.

    A test is produced for every combination of filtered rank, dtype and
    shape (shapes whose rank does not match are skipped), and for every
    argument set returned by the operator's argument generator. For
    ``testType == "negative"`` this is repeated for each ERROR_IF validator
    of the operator.

    Args:
        opName: Name of the operator in ``self.TOSA_OP_LIST``.
        shapeFilter: Optional list of shapes to restrict tests to. ``None``
            (the default) is treated as ``[None]``, i.e. no restriction.
        rankFilter: Optional list of ranks to restrict tests to.
        dtypeFilter: Optional list of dtypes to restrict tests to.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: If ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # Replaces the former shared mutable default of [None]
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Re-seed the random number generator from self.random_seed for each op
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, list) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2698
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Create, build and serialize a single test.

    Looks up ``opName`` in ``TOSA_OP_LIST``, generates the operand tensors
    with the op's tensor-values generator, calls the op's build function and,
    when the build function returns a truthy result, serializes the test
    together with any data-generation / compliance meta data.

    Args:
        opName: Key of the operator in ``TOSA_OP_LIST``.
        testStr: Name of the test being created.
        dtype_or_dtypeList: Either a list of per-operand dtypes, or a single
            dtype that is replicated for every operand (for CONCAT, once per
            entry in ``shapeList``).
        error_name: ERROR_IF error being tested, or None.
        shapeList: List of operand shapes.
        testArgs: Either a dict of arguments (new interface; must contain
            ``"dg_type"``) or a sequence of positional build arguments (old
            interface).

    Raises:
        Exception: If ``opName`` is not in ``TOSA_OP_LIST``.
        AssertionError: If, for a non-CONCAT op, the number of shapes or
            dtypes does not match the op's operand count.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        # The KeyError context adds nothing to this message
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs - one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is only generated for ops that declare a qgen
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands and the test
    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        try:
            result = build_fcn(
                self,
                op,
                tens,
                argsDict,
                validator_fcns=error_if_validators,
                error_name=error_name,
                qinfo=qinfo,
            )
        except TypeError:
            # Signature mismatches are the usual cause - show what was passed
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        try:
            if error_if_validators is None:
                # Without validators, qinfo (when present) is appended
                # positionally after the args, as the original calls did
                extraArgs = (qinfo,) if qinfo is not None else ()
                result = build_fcn(self, op, *tens, *testArgs, *extraArgs)
            else:
                buildKwargs = {
                    "validator_fcns": error_if_validators,
                    "error_name": error_name,
                }
                if qinfo is not None:
                    buildKwargs["qinfo"] = qinfo
                result = build_fcn(self, op, *tens, *testArgs, **buildKwargs)
        except TypeError:
            # Signature mismatches are the usual cause - show what was passed
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002824
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    Each ``<name>_TEMPLATE`` entry for conv2d, depthwise_conv2d and
    transpose_conv2d is cloned (shallow copy) once per 2D kernel size, and
    conv3d once per 3D kernel size. Each clone is named after its kernel
    (e.g. ``conv2d_3x1``), gets ``filter`` set to that kernel and
    ``template`` set to False. Afterwards every entry whose ``template``
    field is truthy is removed.

    The kernel sizes depend on ``self.args.level8k``: a small default set,
    or kernels built from ``TOSA_8K_LEVEL_MAX_KERNEL`` when it is set.

    Does nothing when ``conv2d_TEMPLATE`` is absent, i.e. the templates were
    already expanded (can occur when the class is initialized more than once).
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        return

    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels2d = [[1, bigK], [bigK, 2]]
        kernels3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def addKernelVariant(baseName, kernel):
        # Specialise a copy of the template for a single kernel size
        variant = opList[f"{baseName}_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        kernelTag = "x".join(f"{dim}" for dim in kernel)
        opList[f"{baseName}_{kernelTag}"] = variant

    # Insertion order matters (it drives test order): per 2D kernel the
    # conv2d, depthwise and transpose variants, then all conv3d variants
    for kernel in kernels2d:
        for baseName in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addKernelVariant(baseName, kernel)
    for kernel in kernels3d:
        addKernelVariant("conv3d", kernel)

    # Collect first, delete afterwards - a dict must not change size while
    # it is being iterated
    staleTemplates = [
        name for name, entry in opList.items() if entry.get("template")
    ]
    for name in staleTemplates:
        del opList[name]
2880
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Look for missing required fields (datastructure linting): every entry in
    ``TOSA_OP_LIST`` must have an ``operands`` 2-tuple, a ``build_fcn``
    4-tuple, a ``types`` field and an ``op`` field. Entries without a
    ``rank`` are given ``DEFAULT_RANK_RANGE``.

    Raises:
        Exception: For the first op found with a missing or malformed
            required field; the message names the op and the field.
    """
    for op, opDef in self.TOSA_OP_LIST.items():
        # Required fields - the unpacking also validates the tuple lengths
        try:
            _placeholders, _consts = opDef["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _build, _tgen, _tvgen, _argGen = opDef["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        if "types" not in opDef:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in opDef:
            raise Exception("Op {} is missing the Op field in TOSA_OP_LIST".format(op))

        # Put in default rank range, if missing
        opDef.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002922
# Tensor operator list (TOSA_OP_LIST), keyed by test/op name. Fields:
# 'op': the Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional quantization info generator
# 'invalid_test_validators': optional validators; positive tests for which
#     any validator returns True are removed
# 'error_if_validators': optional ERROR_IF validators given to the build function
# 'data_gen': optional data generation types, keyed by type class (e.g. "fp")
# 'compliance': optional compliance settings (e.g. "ulp", "abs_error_lower_bound")
# 'template' / 'filter': templated convolution ops, expanded into one op per
#     kernel size ("filter") by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002974
2975 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002976 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002977 "argmax": {
2978 "op": Op.ARGMAX,
2979 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002980 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002981 "build_fcn": (
2982 build_argmax,
2983 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002984 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 TosaArgGen.agAxis,
2986 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002987 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002988 "error_if_validators": (
2989 TosaErrorValidator.evAxisSmallerZero,
2990 TosaErrorValidator.evAxisLargerRank,
2991 TosaErrorValidator.evArgmaxOutputRankMismatch,
2992 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2993 TosaErrorValidator.evWrongRank,
2994 TosaErrorValidator.evWrongInputType,
2995 TosaErrorValidator.evWrongOutputType,
2996 TosaErrorValidator.evWrongInputList,
2997 TosaErrorValidator.evWrongOutputList,
2998 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002999 "data_gen": {
3000 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3001 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003002 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003003 "avg_pool2d": {
3004 "op": Op.AVG_POOL2D,
3005 "operands": (1, 0),
3006 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003007 "build_fcn": (
3008 build_pool2d,
3009 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003010 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003011 TosaArgGen.agPooling,
3012 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003013 "qgen": TosaQuantGen.qgUnary,
3014 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003015 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003016 "error_if_validators": (
3017 TosaErrorValidator.evKernelSmallerOne,
3018 TosaErrorValidator.evStrideSmallerOne,
3019 TosaErrorValidator.evPadSmallerZero,
3020 TosaErrorValidator.evWrongRank,
3021 TosaErrorValidator.evWrongInputType,
3022 TosaErrorValidator.evWrongOutputType,
3023 TosaErrorValidator.evWrongInputList,
3024 TosaErrorValidator.evWrongOutputList,
3025 TosaErrorValidator.evInputZeroPointNotZero,
3026 TosaErrorValidator.evOutputZeroPointNotZero,
3027 TosaErrorValidator.evPadLargerEqualKernel,
3028 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003029 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003030 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003031 "data_gen": {
3032 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3033 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003034 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003035 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003036 "conv2d_TEMPLATE": {
3037 "op": Op.CONV2D,
3038 "operands": (1, 2),
3039 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003040 "build_fcn": (
3041 build_conv2d,
3042 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003043 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 TosaArgGen.agConv,
3045 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003046 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003047 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003048 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3049 "error_if_validators": (
3050 TosaErrorValidator.evWrongInputType,
3051 TosaErrorValidator.evWrongOutputType,
3052 TosaErrorValidator.evWrongInputList,
3053 TosaErrorValidator.evWrongOutputList,
3054 TosaErrorValidator.evInputZeroPointNotZero,
3055 TosaErrorValidator.evWeightZeroPointNotZero,
3056 TosaErrorValidator.evPadSmallerZero,
3057 TosaErrorValidator.evStrideSmallerOne,
3058 TosaErrorValidator.evDilationSmallerOne,
3059 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003060 TosaErrorValidator.evConvOutputShapeMismatch,
3061 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003062 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003063 "data_gen": {
3064 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3065 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003066 "template": True,
3067 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003068 # Templated operator. Filled in by createDynamicOpLists
3069 "conv3d_TEMPLATE": {
3070 "op": Op.CONV3D,
3071 "operands": (1, 2),
3072 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003073 "build_fcn": (
3074 build_conv3d,
3075 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003076 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 TosaArgGen.agConv,
3078 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003079 "qgen": TosaQuantGen.qgConv,
3080 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003081 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3082 "error_if_validators": (
3083 TosaErrorValidator.evWrongInputType,
3084 TosaErrorValidator.evWrongOutputType,
3085 TosaErrorValidator.evWrongInputList,
3086 TosaErrorValidator.evWrongOutputList,
3087 TosaErrorValidator.evInputZeroPointNotZero,
3088 TosaErrorValidator.evWeightZeroPointNotZero,
3089 TosaErrorValidator.evPadSmallerZero,
3090 TosaErrorValidator.evStrideSmallerOne,
3091 TosaErrorValidator.evDilationSmallerOne,
3092 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003093 TosaErrorValidator.evConvOutputShapeMismatch,
3094 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003095 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003096 "template": True,
3097 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003098 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003099 "depthwise_conv2d_TEMPLATE": {
3100 "op": Op.DEPTHWISE_CONV2D,
3101 "operands": (1, 2),
3102 "filter": [1, 1],
3103 "rank": (4, 4),
3104 "build_fcn": (
3105 build_depthwise_conv2d,
3106 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003107 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003108 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003109 ),
3110 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003111 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003112 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3113 "error_if_validators": (
3114 TosaErrorValidator.evWrongInputType,
3115 TosaErrorValidator.evWrongOutputType,
3116 TosaErrorValidator.evWrongInputList,
3117 TosaErrorValidator.evWrongOutputList,
3118 TosaErrorValidator.evInputZeroPointNotZero,
3119 TosaErrorValidator.evWeightZeroPointNotZero,
3120 TosaErrorValidator.evPadSmallerZero,
3121 TosaErrorValidator.evStrideSmallerOne,
3122 TosaErrorValidator.evDilationSmallerOne,
3123 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003124 TosaErrorValidator.evConvOutputShapeMismatch,
3125 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003126 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003127 "template": True,
3128 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 "fully_connected": {
3130 "op": Op.FULLY_CONNECTED,
3131 "operands": (1, 2),
3132 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003133 "build_fcn": (
3134 build_fully_connected,
3135 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003136 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003137 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003140 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 "error_if_validators": (
3142 TosaErrorValidator.evInputZeroPointNotZero,
3143 TosaErrorValidator.evWeightZeroPointNotZero,
3144 TosaErrorValidator.evWrongRank,
3145 TosaErrorValidator.evWrongInputType,
3146 TosaErrorValidator.evWrongOutputType,
3147 TosaErrorValidator.evWrongInputList,
3148 TosaErrorValidator.evWrongOutputList,
3149 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003150 "data_gen": {
3151 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3152 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003153 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003154 "matmul": {
3155 "op": Op.MATMUL,
3156 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003157 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003158 "build_fcn": (
3159 build_matmul,
3160 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003161 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003162 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 "qgen": TosaQuantGen.qgMatmul,
3165 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003166 "error_if_validators": (
3167 TosaErrorValidator.evInputZeroPointNotZero,
3168 TosaErrorValidator.evWrongRank,
3169 TosaErrorValidator.evWrongInputType,
3170 TosaErrorValidator.evWrongOutputType,
3171 TosaErrorValidator.evWrongInputList,
3172 TosaErrorValidator.evWrongOutputList,
3173 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003174 "data_gen": {
3175 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003178 "max_pool2d": {
3179 "op": Op.MAX_POOL2D,
3180 "operands": (1, 0),
3181 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003182 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003183 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003184 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003185 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 TosaArgGen.agPooling,
3187 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003188 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003189 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003190 "error_if_validators": (
3191 TosaErrorValidator.evKernelSmallerOne,
3192 TosaErrorValidator.evStrideSmallerOne,
3193 TosaErrorValidator.evPadSmallerZero,
3194 TosaErrorValidator.evWrongRank,
3195 TosaErrorValidator.evWrongInputType,
3196 TosaErrorValidator.evWrongOutputType,
3197 TosaErrorValidator.evWrongInputList,
3198 TosaErrorValidator.evWrongOutputList,
3199 TosaErrorValidator.evPadLargerEqualKernel,
3200 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003201 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003202 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003203 "data_gen": {
3204 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003206 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003207 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003208 "transpose_conv2d_TEMPLATE": {
3209 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003210 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003211 "rank": (4, 4),
3212 "build_fcn": (
3213 build_transpose_conv2d,
3214 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003216 TosaArgGen.agTransposeConv2D,
3217 ),
3218 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003219 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003220 "invalid_test_validators": (
3221 TosaInvalidValidator.ivHeightWidthInvalid,
3222 TosaInvalidValidator.ivNonPositiveOutputShape,
3223 ),
3224 "error_if_validators": (
3225 TosaErrorValidator.evWrongInputType,
3226 TosaErrorValidator.evWrongOutputType,
3227 TosaErrorValidator.evWrongInputList,
3228 TosaErrorValidator.evWrongOutputList,
3229 TosaErrorValidator.evInputZeroPointNotZero,
3230 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003231 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003232 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003233 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003234 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003235 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003236 "template": True,
3237 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003238 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003239 "clamp": {
3240 "op": Op.CLAMP,
3241 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003242 "build_fcn": (
3243 build_clamp,
3244 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003245 TosaTensorValuesGen.tvgLazyGenDefault,
3246 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003248 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003249 "error_if_validators": (
3250 TosaErrorValidator.evMaxSmallerMin,
3251 TosaErrorValidator.evWrongInputType,
3252 TosaErrorValidator.evWrongOutputType,
3253 TosaErrorValidator.evWrongInputList,
3254 TosaErrorValidator.evWrongOutputList,
3255 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003256 "data_gen": {
3257 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3258 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003259 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003260 "sigmoid": {
3261 "op": Op.SIGMOID,
3262 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003263 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003264 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003265 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003266 TosaTensorValuesGen.tvgLazyGenDefault,
3267 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003269 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003270 "error_if_validators": (
3271 TosaErrorValidator.evWrongInputType,
3272 TosaErrorValidator.evWrongOutputType,
3273 TosaErrorValidator.evWrongInputList,
3274 TosaErrorValidator.evWrongOutputList,
3275 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003276 "data_gen": {
3277 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3278 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003279 },
3280 "tanh": {
3281 "op": Op.TANH,
3282 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003283 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003284 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003286 TosaTensorValuesGen.tvgLazyGenDefault,
3287 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003288 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003289 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003290 "error_if_validators": (
3291 TosaErrorValidator.evWrongInputType,
3292 TosaErrorValidator.evWrongOutputType,
3293 TosaErrorValidator.evWrongInputList,
3294 TosaErrorValidator.evWrongOutputList,
3295 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003296 "data_gen": {
3297 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3298 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003299 "compliance": {
3300 "abs_error_lower_bound": 0.5,
3301 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003302 },
Won Jeon78155c62023-06-10 00:20:04 +00003303 "erf": {
3304 "op": Op.ERF,
3305 "operands": (1, 0),
3306 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003307 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003308 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003309 TosaTensorValuesGen.tvgLazyGenDefault,
3310 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003311 ),
3312 "types": TYPE_FP,
3313 "error_if_validators": (
3314 TosaErrorValidator.evWrongInputType,
3315 TosaErrorValidator.evWrongOutputType,
3316 TosaErrorValidator.evWrongInputList,
3317 TosaErrorValidator.evWrongOutputList,
3318 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003319 "data_gen": {
3320 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3321 },
3322 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 # Elementwise Binary Operators
3325 "add": {
3326 "op": Op.ADD,
3327 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003328 "build_fcn": (
3329 build_binary_broadcast,
3330 TosaTensorGen.tgBroadcastFuzz,
3331 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003332 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003333 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003335 "error_if_validators": (
3336 TosaErrorValidator.evRankMismatch,
3337 TosaErrorValidator.evWrongInputType,
3338 TosaErrorValidator.evWrongOutputType,
3339 TosaErrorValidator.evWrongInputList,
3340 TosaErrorValidator.evWrongOutputList,
3341 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003342 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003343 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003344 "data_gen": {
3345 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3346 },
3347 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 "arithmetic_right_shift": {
3350 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3351 "operands": (2, 0),
3352 "build_fcn": (
3353 build_arithmetic_right_shift,
3354 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 TosaArgGen.agArithmeticRightShift,
3357 ),
3358 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003359 "error_if_validators": (
3360 TosaErrorValidator.evRankMismatch,
3361 TosaErrorValidator.evWrongInputType,
3362 TosaErrorValidator.evWrongOutputType,
3363 TosaErrorValidator.evWrongInputList,
3364 TosaErrorValidator.evWrongOutputList,
3365 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003366 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003369 "bitwise_and": {
3370 "op": Op.BITWISE_AND,
3371 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 "build_fcn": (
3373 build_binary_broadcast,
3374 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003375 TosaTensorValuesGen.tvgLazyGenDefault,
3376 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003377 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003378 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003379 "error_if_validators": (
3380 TosaErrorValidator.evRankMismatch,
3381 TosaErrorValidator.evWrongInputType,
3382 TosaErrorValidator.evWrongOutputType,
3383 TosaErrorValidator.evWrongInputList,
3384 TosaErrorValidator.evWrongOutputList,
3385 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003386 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003387 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "bitwise_or": {
3390 "op": Op.BITWISE_OR,
3391 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 "build_fcn": (
3393 build_binary_broadcast,
3394 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003395 TosaTensorValuesGen.tvgLazyGenDefault,
3396 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 "error_if_validators": (
3400 TosaErrorValidator.evRankMismatch,
3401 TosaErrorValidator.evWrongInputType,
3402 TosaErrorValidator.evWrongOutputType,
3403 TosaErrorValidator.evWrongInputList,
3404 TosaErrorValidator.evWrongOutputList,
3405 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003406 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "bitwise_xor": {
3410 "op": Op.BITWISE_XOR,
3411 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 "build_fcn": (
3413 build_binary_broadcast,
3414 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003415 TosaTensorValuesGen.tvgLazyGenDefault,
3416 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003419 "error_if_validators": (
3420 TosaErrorValidator.evRankMismatch,
3421 TosaErrorValidator.evWrongInputType,
3422 TosaErrorValidator.evWrongOutputType,
3423 TosaErrorValidator.evWrongInputList,
3424 TosaErrorValidator.evWrongOutputList,
3425 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003426 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003429 "intdiv": {
3430 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003431 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 "build_fcn": (
3433 build_binary_broadcast,
3434 TosaTensorGen.tgBroadcastFuzz,
3435 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003436 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003438 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003439 "error_if_validators": (
3440 TosaErrorValidator.evRankMismatch,
3441 TosaErrorValidator.evWrongInputType,
3442 TosaErrorValidator.evWrongOutputType,
3443 TosaErrorValidator.evWrongInputList,
3444 TosaErrorValidator.evWrongOutputList,
3445 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003446 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "logical_and": {
3450 "op": Op.LOGICAL_AND,
3451 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452 "build_fcn": (
3453 build_binary_broadcast,
3454 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003455 TosaTensorValuesGen.tvgLazyGenDefault,
3456 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003459 "error_if_validators": (
3460 TosaErrorValidator.evRankMismatch,
3461 TosaErrorValidator.evWrongInputType,
3462 TosaErrorValidator.evWrongOutputType,
3463 TosaErrorValidator.evWrongInputList,
3464 TosaErrorValidator.evWrongOutputList,
3465 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003466 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003468 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 "logical_left_shift": {
3470 "op": Op.LOGICAL_LEFT_SHIFT,
3471 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 "build_fcn": (
3473 build_binary_broadcast,
3474 TosaTensorGen.tgBroadcastFuzz,
3475 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003476 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003477 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003479 "error_if_validators": (
3480 TosaErrorValidator.evRankMismatch,
3481 TosaErrorValidator.evWrongInputType,
3482 TosaErrorValidator.evWrongOutputType,
3483 TosaErrorValidator.evWrongInputList,
3484 TosaErrorValidator.evWrongOutputList,
3485 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003486 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003487 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003489 "logical_right_shift": {
3490 "op": Op.LOGICAL_RIGHT_SHIFT,
3491 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003492 "build_fcn": (
3493 build_binary_broadcast,
3494 TosaTensorGen.tgBroadcastFuzz,
3495 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003496 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003499 "error_if_validators": (
3500 TosaErrorValidator.evRankMismatch,
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003506 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003507 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003508 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003509 "logical_or": {
3510 "op": Op.LOGICAL_OR,
3511 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 "build_fcn": (
3513 build_binary_broadcast,
3514 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003515 TosaTensorValuesGen.tvgLazyGenDefault,
3516 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003517 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003518 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003519 "error_if_validators": (
3520 TosaErrorValidator.evRankMismatch,
3521 TosaErrorValidator.evWrongInputType,
3522 TosaErrorValidator.evWrongOutputType,
3523 TosaErrorValidator.evWrongInputList,
3524 TosaErrorValidator.evWrongOutputList,
3525 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003526 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003527 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 "logical_xor": {
3530 "op": Op.LOGICAL_XOR,
3531 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532 "build_fcn": (
3533 build_binary_broadcast,
3534 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003535 TosaTensorValuesGen.tvgLazyGenDefault,
3536 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003537 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003539 "error_if_validators": (
3540 TosaErrorValidator.evRankMismatch,
3541 TosaErrorValidator.evWrongInputType,
3542 TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003546 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "maximum": {
3550 "op": Op.MAXIMUM,
3551 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_binary_broadcast,
3554 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003555 TosaTensorValuesGen.tvgLazyGenDefault,
3556 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003559 "error_if_validators": (
3560 TosaErrorValidator.evRankMismatch,
3561 TosaErrorValidator.evWrongInputType,
3562 TosaErrorValidator.evWrongOutputType,
3563 TosaErrorValidator.evWrongInputList,
3564 TosaErrorValidator.evWrongOutputList,
3565 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003566 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003567 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003568 "data_gen": {
3569 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003572 "minimum": {
3573 "op": Op.MINIMUM,
3574 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 "build_fcn": (
3576 build_binary_broadcast,
3577 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003578 TosaTensorValuesGen.tvgLazyGenDefault,
3579 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003580 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 "error_if_validators": (
3583 TosaErrorValidator.evRankMismatch,
3584 TosaErrorValidator.evWrongInputType,
3585 TosaErrorValidator.evWrongOutputType,
3586 TosaErrorValidator.evWrongInputList,
3587 TosaErrorValidator.evWrongOutputList,
3588 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003589 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003590 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003591 "data_gen": {
3592 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3593 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 "mul": {
3596 "op": Op.MUL,
3597 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 "build_fcn": (
3599 build_mul,
3600 TosaTensorGen.tgBroadcastFuzz,
3601 TosaTensorValuesGen.tvgMul,
3602 TosaArgGen.agMul,
3603 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evWrongInputType,
3607 TosaErrorValidator.evWrongOutputType,
3608 TosaErrorValidator.evWrongInputList,
3609 TosaErrorValidator.evWrongOutputList,
3610 TosaErrorValidator.evRankMismatch,
3611 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003612 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003613 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003614 "data_gen": {
3615 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3616 },
3617 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 "pow": {
3620 "op": Op.POW,
3621 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 "build_fcn": (
3623 build_binary_broadcast,
3624 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003625 TosaTensorValuesGen.tvgPow,
3626 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003627 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 "error_if_validators": (
3630 TosaErrorValidator.evRankMismatch,
3631 TosaErrorValidator.evWrongInputType,
3632 TosaErrorValidator.evWrongOutputType,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003636 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003638 "data_gen": {
3639 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3640 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003641 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 "sub": {
3643 "op": Op.SUB,
3644 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003645 "build_fcn": (
3646 build_binary_broadcast,
3647 TosaTensorGen.tgBroadcastFuzz,
3648 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003649 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003652 "error_if_validators": (
3653 TosaErrorValidator.evRankMismatch,
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongInputList,
3657 TosaErrorValidator.evWrongOutputList,
3658 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003659 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003660 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003661 "data_gen": {
3662 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3663 },
3664 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 "table": {
3667 "op": Op.TABLE,
3668 # Use the automatic generation functions to create the input array
3669 # but create the table tensor in the build function, as it may be
3670 # a different type from the input
3671 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_table,
3674 TosaTensorGen.tgBasic,
3675 TosaTensorValuesGen.tvgDefault,
3676 TosaArgGen.agTable,
3677 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003678 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evWrongInputType,
3681 TosaErrorValidator.evWrongOutputType,
3682 TosaErrorValidator.evWrongInputList,
3683 TosaErrorValidator.evWrongOutputList,
3684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 # Elementwise Unary operators
3687 "abs": {
3688 "op": Op.ABS,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_unary,
3692 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003693 TosaTensorValuesGen.tvgLazyGenDefault,
3694 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evWrongInputType,
3699 TosaErrorValidator.evWrongOutputType,
3700 TosaErrorValidator.evWrongInputList,
3701 TosaErrorValidator.evWrongOutputList,
3702 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003703 "data_gen": {
3704 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003707 "bitwise_not": {
3708 "op": Op.BITWISE_NOT,
3709 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 "build_fcn": (
3711 build_unary,
3712 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003713 TosaTensorValuesGen.tvgLazyGenDefault,
3714 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evWrongInputType,
3719 TosaErrorValidator.evWrongOutputType,
3720 TosaErrorValidator.evWrongInputList,
3721 TosaErrorValidator.evWrongOutputList,
3722 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "ceil": {
3725 "op": Op.CEIL,
3726 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 "build_fcn": (
3728 build_unary,
3729 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003730 TosaTensorValuesGen.tvgLazyGenDefault,
3731 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evWrongInputType,
3736 TosaErrorValidator.evWrongOutputType,
3737 TosaErrorValidator.evWrongInputList,
3738 TosaErrorValidator.evWrongOutputList,
3739 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003740 "data_gen": {
3741 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3742 },
3743 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 "clz": {
3746 "op": Op.CLZ,
3747 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003748 "build_fcn": (
3749 build_unary,
3750 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003751 TosaTensorValuesGen.tvgLazyGenDefault,
3752 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003753 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003755 "error_if_validators": (
3756 TosaErrorValidator.evWrongInputType,
3757 TosaErrorValidator.evWrongOutputType,
3758 TosaErrorValidator.evWrongInputList,
3759 TosaErrorValidator.evWrongOutputList,
3760 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "exp": {
3763 "op": Op.EXP,
3764 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003765 "build_fcn": (
3766 build_unary,
3767 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003768 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003769 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003770 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003772 "error_if_validators": (
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003778 "data_gen": {
3779 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "floor": {
3783 "op": Op.FLOOR,
3784 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_unary,
3787 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003788 TosaTensorValuesGen.tvgLazyGenDefault,
3789 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003798 "data_gen": {
3799 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3800 },
3801 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 "log": {
3804 "op": Op.LOG,
3805 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003806 "build_fcn": (
3807 build_unary,
3808 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003809 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003810 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003811 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003812 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003813 "error_if_validators": (
3814 TosaErrorValidator.evWrongInputType,
3815 TosaErrorValidator.evWrongOutputType,
3816 TosaErrorValidator.evWrongInputList,
3817 TosaErrorValidator.evWrongOutputList,
3818 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003819 "data_gen": {
3820 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3821 },
3822 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 "logical_not": {
3825 "op": Op.LOGICAL_NOT,
3826 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 "build_fcn": (
3828 build_unary,
3829 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003830 TosaTensorValuesGen.tvgLazyGenDefault,
3831 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 "error_if_validators": (
3835 TosaErrorValidator.evWrongInputType,
3836 TosaErrorValidator.evWrongOutputType,
3837 TosaErrorValidator.evWrongInputList,
3838 TosaErrorValidator.evWrongOutputList,
3839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "negate": {
3842 "op": Op.NEGATE,
3843 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_unary,
3846 TosaTensorGen.tgBasic,
3847 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003848 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "qgen": TosaQuantGen.qgUnary,
3851 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003852 "error_if_validators": (
3853 TosaErrorValidator.evInputZeroPointNotZero,
3854 TosaErrorValidator.evOutputZeroPointNotZero,
3855 TosaErrorValidator.evWrongInputType,
3856 TosaErrorValidator.evWrongOutputType,
3857 TosaErrorValidator.evWrongInputList,
3858 TosaErrorValidator.evWrongOutputList,
3859 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003860 "data_gen": {
3861 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3862 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "reciprocal": {
3865 "op": Op.RECIPROCAL,
3866 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 "build_fcn": (
3868 build_unary,
3869 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003870 TosaTensorValuesGen.tvgLazyGenDefault,
3871 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003872 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 "error_if_validators": (
3875 TosaErrorValidator.evWrongInputType,
3876 TosaErrorValidator.evWrongOutputType,
3877 TosaErrorValidator.evWrongInputList,
3878 TosaErrorValidator.evWrongOutputList,
3879 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003880 "data_gen": {
3881 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3882 },
3883 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003885 "rsqrt": {
3886 "op": Op.RSQRT,
3887 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003888 "build_fcn": (
3889 build_unary,
3890 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003891 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003892 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003893 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003894 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 "error_if_validators": (
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList,
3899 TosaErrorValidator.evWrongOutputList,
3900 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003901 "data_gen": {
3902 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3903 },
3904 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 # Elementwise Ternary operators
3907 "select": {
3908 "op": Op.SELECT,
3909 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 "build_fcn": (
3911 build_select,
3912 TosaTensorGen.tgBroadcastFuzz,
3913 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003914 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003916 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 "error_if_validators": (
3918 TosaErrorValidator.evRankMismatch,
3919 TosaErrorValidator.evWrongInputType,
3920 TosaErrorValidator.evWrongOutputType,
3921 TosaErrorValidator.evWrongInputList,
3922 TosaErrorValidator.evWrongOutputList,
3923 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003924 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00003926 "data_gen": {
3927 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3928 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003929 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 # Comparison operators
3931 "equal": {
3932 "op": Op.EQUAL,
3933 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 "build_fcn": (
3935 build_comparison,
3936 TosaTensorGen.tgBroadcastFuzz,
3937 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003938 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003939 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003940 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003941 "error_if_validators": (
3942 TosaErrorValidator.evRankMismatch,
3943 TosaErrorValidator.evWrongInputType,
3944 TosaErrorValidator.evWrongOutputType,
3945 TosaErrorValidator.evWrongInputList,
3946 TosaErrorValidator.evWrongOutputList,
3947 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003948 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003950 "data_gen": {
3951 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3952 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003954 "greater_equal": {
3955 "op": Op.GREATER_EQUAL,
3956 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003957 "build_fcn": (
3958 build_comparison,
3959 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003960 TosaTensorValuesGen.tvgLazyGenDefault,
3961 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 "error_if_validators": (
3965 TosaErrorValidator.evRankMismatch,
3966 TosaErrorValidator.evWrongInputType,
3967 TosaErrorValidator.evWrongOutputType,
3968 TosaErrorValidator.evWrongInputList,
3969 TosaErrorValidator.evWrongOutputList,
3970 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003971 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003973 "data_gen": {
3974 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003977 "greater": {
3978 "op": Op.GREATER,
3979 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_comparison,
3982 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003983 TosaTensorValuesGen.tvgLazyGenDefault,
3984 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evRankMismatch,
3989 TosaErrorValidator.evWrongInputType,
3990 TosaErrorValidator.evWrongOutputType,
3991 TosaErrorValidator.evWrongInputList,
3992 TosaErrorValidator.evWrongOutputList,
3993 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003994 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003996 "data_gen": {
3997 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 # Reduction operators
4001 "reduce_all": {
4002 "op": Op.REDUCE_ALL,
4003 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004004 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004005 "build_fcn": (
4006 build_reduce,
4007 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004008 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 TosaArgGen.agAxis,
4010 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004012 "error_if_validators": (
4013 TosaErrorValidator.evAxisLargerRank,
4014 TosaErrorValidator.evAxisSmallerZero,
4015 TosaErrorValidator.evShapeOfAxisNotOne,
4016 TosaErrorValidator.evWrongInputType,
4017 TosaErrorValidator.evWrongOutputType,
4018 TosaErrorValidator.evWrongRank,
4019 TosaErrorValidator.evWrongInputList,
4020 TosaErrorValidator.evWrongOutputList,
4021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004023 "reduce_any": {
4024 "op": Op.REDUCE_ANY,
4025 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004026 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004027 "build_fcn": (
4028 build_reduce,
4029 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004030 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004031 TosaArgGen.agAxis,
4032 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004034 "error_if_validators": (
4035 TosaErrorValidator.evAxisLargerRank,
4036 TosaErrorValidator.evAxisSmallerZero,
4037 TosaErrorValidator.evShapeOfAxisNotOne,
4038 TosaErrorValidator.evWrongInputType,
4039 TosaErrorValidator.evWrongOutputType,
4040 TosaErrorValidator.evWrongRank,
4041 TosaErrorValidator.evWrongInputList,
4042 TosaErrorValidator.evWrongOutputList,
4043 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004044 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 "reduce_max": {
4046 "op": Op.REDUCE_MAX,
4047 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004048 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004049 "build_fcn": (
4050 build_reduce,
4051 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004052 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004053 TosaArgGen.agAxis,
4054 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004056 "error_if_validators": (
4057 TosaErrorValidator.evAxisLargerRank,
4058 TosaErrorValidator.evAxisSmallerZero,
4059 TosaErrorValidator.evShapeOfAxisNotOne,
4060 TosaErrorValidator.evWrongInputType,
4061 TosaErrorValidator.evWrongOutputType,
4062 TosaErrorValidator.evWrongRank,
4063 TosaErrorValidator.evWrongInputList,
4064 TosaErrorValidator.evWrongOutputList,
4065 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004066 "data_gen": {
4067 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4068 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004070 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004071 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004073 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004074 "build_fcn": (
4075 build_reduce,
4076 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004077 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004078 TosaArgGen.agAxis,
4079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004080 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004081 "error_if_validators": (
4082 TosaErrorValidator.evAxisLargerRank,
4083 TosaErrorValidator.evAxisSmallerZero,
4084 TosaErrorValidator.evShapeOfAxisNotOne,
4085 TosaErrorValidator.evWrongInputType,
4086 TosaErrorValidator.evWrongOutputType,
4087 TosaErrorValidator.evWrongRank,
4088 TosaErrorValidator.evWrongInputList,
4089 TosaErrorValidator.evWrongOutputList,
4090 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004091 "data_gen": {
4092 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 "reduce_product": {
4096 "op": Op.REDUCE_PRODUCT,
4097 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004098 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004099 "build_fcn": (
4100 build_reduce,
4101 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004102 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004103 TosaArgGen.agAxis,
4104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 "error_if_validators": (
4107 TosaErrorValidator.evAxisLargerRank,
4108 TosaErrorValidator.evAxisSmallerZero,
4109 TosaErrorValidator.evShapeOfAxisNotOne,
4110 TosaErrorValidator.evWrongInputType,
4111 TosaErrorValidator.evWrongOutputType,
4112 TosaErrorValidator.evWrongRank,
4113 TosaErrorValidator.evWrongInputList,
4114 TosaErrorValidator.evWrongOutputList,
4115 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004116 "data_gen": {
4117 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 "reduce_sum": {
4121 "op": Op.REDUCE_SUM,
4122 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004123 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004124 "build_fcn": (
4125 build_reduce,
4126 TosaTensorGen.tgBasic,
4127 TosaTensorValuesGen.tvgReduceSum,
4128 TosaArgGen.agAxis,
4129 ),
James Ward24dbc422022-10-19 12:20:31 +01004130 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004131 "error_if_validators": (
4132 TosaErrorValidator.evAxisLargerRank,
4133 TosaErrorValidator.evAxisSmallerZero,
4134 TosaErrorValidator.evShapeOfAxisNotOne,
4135 TosaErrorValidator.evWrongInputType,
4136 TosaErrorValidator.evWrongOutputType,
4137 TosaErrorValidator.evWrongRank,
4138 TosaErrorValidator.evWrongInputList,
4139 TosaErrorValidator.evWrongOutputList,
4140 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004141 "data_gen": {
4142 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004145 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004146 "concat": {
4147 "op": Op.CONCAT,
4148 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 "build_fcn": (
4150 build_concat,
4151 TosaTensorGen.tgConcat,
4152 TosaTensorValuesGen.tvgConcat,
4153 TosaArgGen.agAxis,
4154 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004155 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004156 "error_if_validators": (
4157 TosaErrorValidator.evAxisLargerRank,
4158 TosaErrorValidator.evAxisSmallerZero,
4159 TosaErrorValidator.evConcatInputRankMismatch,
4160 TosaErrorValidator.evConcatShapeSumMismatch,
4161 TosaErrorValidator.evConcatInputDimMismatch,
4162 TosaErrorValidator.evWrongInputType,
4163 TosaErrorValidator.evWrongOutputType,
4164 TosaErrorValidator.evWrongOutputList,
4165 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004166 "data_gen": {
4167 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4168 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004169 },
4170 "pad": {
4171 "op": Op.PAD,
4172 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 "build_fcn": (
4174 build_pad,
4175 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004176 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004177 TosaArgGen.agPad,
4178 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004179 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004180 "error_if_validators": (
4181 TosaErrorValidator.evWrongInputType,
4182 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004183 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004184 TosaErrorValidator.evWrongOutputType,
4185 TosaErrorValidator.evWrongInputList,
4186 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004187 TosaErrorValidator.evRankMismatch,
4188 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004189 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004190 "data_gen": {
4191 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4192 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004193 },
Won Jeona21b2e82023-08-10 10:33:01 +00004194 "dim": {
4195 "op": Op.DIM,
4196 "operands": (1, 0),
4197 "build_fcn": (
4198 build_dim,
4199 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004200 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004201 TosaArgGen.agAxis,
4202 ),
4203 "types": TYPE_FIB,
4204 "error_if_validators": (
4205 TosaErrorValidator.evAxisLargerRank,
4206 TosaErrorValidator.evAxisSmallerZero,
4207 TosaErrorValidator.evWrongInputType,
4208 TosaErrorValidator.evWrongInputList,
4209 TosaErrorValidator.evWrongOutputList,
4210 TosaErrorValidator.evWrongRank,
4211 ),
4212 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004213 "reshape": {
4214 "op": Op.RESHAPE,
4215 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004216 "build_fcn": (
4217 build_reshape,
4218 TosaTensorGen.tgBasic,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004219 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 TosaArgGen.agReshape,
4221 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004222 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004223 "error_if_validators": (
4224 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4225 TosaErrorValidator.evWrongInputType,
4226 TosaErrorValidator.evWrongOutputType,
4227 TosaErrorValidator.evWrongInputList,
4228 TosaErrorValidator.evWrongOutputList,
4229 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004230 "data_gen": {
4231 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4232 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004233 },
4234 "reverse": {
4235 "op": Op.REVERSE,
4236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_reverse,
4239 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004240 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 TosaArgGen.agAxis,
4242 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004243 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004244 "error_if_validators": (
4245 TosaErrorValidator.evAxisSmallerZero,
4246 TosaErrorValidator.evAxisLargerRank,
4247 TosaErrorValidator.evWrongInputType,
4248 TosaErrorValidator.evWrongOutputType,
4249 TosaErrorValidator.evWrongInputList,
4250 TosaErrorValidator.evWrongOutputList,
4251 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004252 },
4253 "slice": {
4254 "op": Op.SLICE,
4255 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004256 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004257 "build_fcn": (
4258 build_slice,
4259 TosaTensorGen.tgBasic,
4260 TosaTensorValuesGen.tvgDefault,
4261 TosaArgGen.agSlice,
4262 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004263 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004264 "error_if_validators": (
4265 TosaErrorValidator.evStartSmallerZero,
4266 TosaErrorValidator.evSizeSmallerEqualZero,
4267 TosaErrorValidator.evStartSizeOutsideBounds,
4268 TosaErrorValidator.evSizeOutputShapeMismatch,
4269 TosaErrorValidator.evInputSizeStartLengthMismatch,
4270 TosaErrorValidator.evWrongRank,
4271 TosaErrorValidator.evWrongInputType,
4272 TosaErrorValidator.evWrongOutputType,
4273 TosaErrorValidator.evWrongInputList,
4274 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004275 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004276 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004277 },
4278 "tile": {
4279 "op": Op.TILE,
4280 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004281 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_tile,
4284 TosaTensorGen.tgBasic,
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004285 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004286 TosaArgGen.agTile,
4287 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004288 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004289 "error_if_validators": (
4290 TosaErrorValidator.evWrongInputType,
4291 TosaErrorValidator.evWrongOutputType,
4292 TosaErrorValidator.evWrongInputList,
4293 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004294 TosaErrorValidator.evRankMismatch,
4295 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004296 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004297 "data_gen": {
4298 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4299 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004300 },
4301 "transpose": {
4302 "op": Op.TRANSPOSE,
4303 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004304 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004305 "build_fcn": (
4306 build_transpose,
4307 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004308 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 TosaArgGen.agTranspose,
4310 ),
4311 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 "error_if_validators": (
4313 TosaErrorValidator.evIndexOutsideBounds,
4314 TosaErrorValidator.evIndexUsedTwice,
4315 TosaErrorValidator.evWrongInputType,
4316 TosaErrorValidator.evWrongOutputType,
4317 TosaErrorValidator.evWrongInputList,
4318 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004319 TosaErrorValidator.evWrongRank,
4320 TosaErrorValidator.evRankMismatch,
4321 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004322 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 # Data nodes
4325 "const": {
4326 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004327 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004328 "build_fcn": (
4329 build_const,
4330 TosaTensorGen.tgBasic,
4331 TosaTensorValuesGen.tvgDefault,
4332 None,
4333 ),
Luke Hutton65872422023-02-20 10:33:04 +00004334 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004336 "identity": {
4337 "op": Op.IDENTITY,
4338 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 "build_fcn": (
4340 build_unary,
4341 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004342 TosaTensorValuesGen.tvgLazyGenDefault,
4343 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004344 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004345 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004346 "data_gen": {
4347 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4348 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004349 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004350 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004351 "gather": {
4352 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004353 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004355 "build_fcn": (
4356 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004357 TosaTensorGen.tgGather,
4358 TosaTensorValuesGen.tvgGather,
4359 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004360 ),
James Ward24dbc422022-10-19 12:20:31 +01004361 "types": (
4362 DType.INT8,
4363 DType.INT16,
4364 DType.INT32,
4365 DType.FP16,
4366 DType.BF16,
4367 DType.FP32,
4368 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004369 "error_if_validators": (
4370 TosaErrorValidator.evWrongInputType,
4371 TosaErrorValidator.evWrongOutputType,
4372 TosaErrorValidator.evWrongInputList,
4373 TosaErrorValidator.evWrongOutputList,
4374 TosaErrorValidator.evWrongRank,
4375 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004376 "data_gen": {
4377 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4378 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004379 },
4380 "scatter": {
4381 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004382 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004383 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004384 "build_fcn": (
4385 build_scatter,
4386 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004387 TosaTensorValuesGen.tvgScatter,
4388 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004389 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004390 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004391 "error_if_validators": (
4392 TosaErrorValidator.evWrongInputType,
4393 TosaErrorValidator.evWrongOutputType,
4394 TosaErrorValidator.evWrongInputList,
4395 TosaErrorValidator.evWrongOutputList,
4396 TosaErrorValidator.evWrongRank,
4397 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004398 "data_gen": {
4399 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4400 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004401 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004402 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004403 "resize": {
4404 "op": Op.RESIZE,
4405 "operands": (1, 0),
4406 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004407 "build_fcn": (
4408 build_resize,
4409 TosaTensorGen.tgNHWC,
4410 TosaTensorValuesGen.tvgDefault,
4411 TosaArgGen.agResize,
4412 ),
James Ward24dbc422022-10-19 12:20:31 +01004413 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004414 "invalid_test_validators": (
4415 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004416 ),
4417 "error_if_validators": (
4418 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004419 TosaErrorValidator.evScaleSmallerEqualZero,
4420 TosaErrorValidator.evScaleNLargerMax,
4421 TosaErrorValidator.evScaleDLargerMax,
4422 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004423 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004424 TosaErrorValidator.evBorderSmallerMin,
4425 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004426 TosaErrorValidator.evWrongInputType,
4427 TosaErrorValidator.evWrongOutputType,
4428 TosaErrorValidator.evWrongRank,
4429 TosaErrorValidator.evWrongInputList,
4430 TosaErrorValidator.evWrongOutputList,
4431 TosaErrorValidator.evBatchMismatch,
4432 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004433 TosaErrorValidator.evResizeOutputShapeMismatch,
4434 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004436 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004437 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004438 "cast": {
4439 "op": Op.CAST,
4440 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004441 "build_fcn": (
4442 build_cast,
4443 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004444 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004445 TosaArgGen.agCast,
4446 ),
James Ward8b390432022-08-12 20:48:56 +01004447 "types": (
4448 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004449 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004450 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004451 DType.INT8,
4452 DType.INT16,
4453 DType.INT32,
4454 DType.BOOL,
4455 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004456 "error_if_validators": (
4457 TosaErrorValidator.evWrongInputType,
4458 TosaErrorValidator.evWrongOutputType,
4459 TosaErrorValidator.evWrongInputList,
4460 TosaErrorValidator.evWrongOutputList,
4461 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004462 "data_gen": {
4463 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4464 },
4465 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004466 },
4467 "rescale": {
4468 "op": Op.RESCALE,
4469 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004470 "build_fcn": (
4471 build_rescale,
4472 TosaTensorGen.tgBasic,
4473 TosaTensorValuesGen.tvgDefault,
4474 TosaArgGen.agRescale,
4475 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004476 "types": [
4477 DType.UINT8,
4478 DType.INT8,
4479 DType.INT16,
4480 DType.INT32,
4481 DType.INT48,
4482 DType.UINT16,
4483 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004484 "error_if_validators": (
4485 TosaErrorValidator.evInputZeroPointNotZero,
4486 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004487 TosaErrorValidator.evU16InputZeroPointNotValid,
4488 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004489 TosaErrorValidator.evScaleTrue,
4490 TosaErrorValidator.evScaleNotTrue,
4491 TosaErrorValidator.evWrongInputType,
4492 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004493 TosaErrorValidator.evWrongInputList,
4494 TosaErrorValidator.evWrongOutputList,
4495 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004496 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004497 # Custom
4498 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004499 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output), and another that either adds
        # or subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004503 "cond_if_const": {
4504 "op": Op.COND_IF,
4505 "operands": (0, 2),
4506 "build_fcn": (
4507 build_cond_if_const,
4508 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004509 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004510 TosaArgGen.agCondIf,
4511 ),
4512 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004513 "error_if_validators": (
4514 TosaErrorValidator.evOutputListThenGraphMismatch,
4515 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004516 TosaErrorValidator.evCondIfCondNotMatchingBool,
4517 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004519 },
4520 "cond_if_binary": {
4521 "op": Op.COND_IF,
4522 "operands": (2, 0),
4523 "build_fcn": (
4524 build_cond_if_binary,
4525 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004526 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004527 TosaArgGen.agCondIf,
4528 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004529 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004530 "error_if_validators": (
4531 TosaErrorValidator.evInputListThenGraphMismatch,
4532 TosaErrorValidator.evInputListElseGraphMismatch,
4533 TosaErrorValidator.evOutputListThenGraphMismatch,
4534 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004535 TosaErrorValidator.evCondIfCondNotMatchingBool,
4536 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004537 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004538 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004539 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004540 "while_loop": {
4541 "op": Op.WHILE_LOOP,
4542 "operands": (0, 1),
4543 "build_fcn": (
4544 build_while_loop,
4545 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004546 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004547 TosaArgGen.agWhileLoop,
4548 ),
4549 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004550 "error_if_validators": (
4551 TosaErrorValidator.evInputListOutputListMismatch,
4552 TosaErrorValidator.evInputListCondGraphMismatch,
4553 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4554 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4555 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004556 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004557 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004558 },
Luke Hutton57287132023-02-06 14:54:18 +00004559 "fft2d": {
4560 "op": Op.FFT2D,
4561 "operands": (2, 0),
4562 "rank": (3, 3),
4563 "build_fcn": (
4564 build_fft2d,
4565 TosaTensorGen.tgFFT2d,
4566 TosaTensorValuesGen.tvgDefault,
4567 TosaArgGen.agFFT2d,
4568 ),
4569 "types": [DType.FP32],
4570 "error_if_validators": (
4571 TosaErrorValidator.evWrongInputType,
4572 TosaErrorValidator.evWrongOutputType,
4573 TosaErrorValidator.evWrongInputList,
4574 TosaErrorValidator.evWrongOutputList,
4575 TosaErrorValidator.evWrongRank,
4576 TosaErrorValidator.evBatchMismatch,
4577 TosaErrorValidator.evKernelNotPowerOfTwo,
4578 TosaErrorValidator.evFFTInputShapeMismatch,
4579 TosaErrorValidator.evFFTOutputShapeMismatch,
4580 ),
4581 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004582 "rfft2d": {
4583 "op": Op.RFFT2D,
4584 "operands": (1, 0),
4585 "rank": (3, 3),
4586 "build_fcn": (
4587 build_rfft2d,
4588 TosaTensorGen.tgRFFT2d,
4589 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004590 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004591 ),
4592 "types": [DType.FP32],
4593 "error_if_validators": (
4594 TosaErrorValidator.evWrongInputType,
4595 TosaErrorValidator.evWrongOutputType,
4596 TosaErrorValidator.evWrongInputList,
4597 TosaErrorValidator.evWrongOutputList,
4598 TosaErrorValidator.evWrongRank,
4599 TosaErrorValidator.evBatchMismatch,
4600 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004601 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004602 ),
4603 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004604 }
4605
Kevin Cheng550ccc52021-03-03 11:21:43 -08004606
Eric Kunzee5e26762020-10-13 16:11:07 -07004607class OutputShaper:
4608 # Methods in this class compute the expected output shape and datatype
4609 # for common classes of operations
4610 def __init__(self):
4611 pass
4612
4613 # These methods return arguments that can be used for
4614 # creating a new output tensor
4615 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004616 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4617 if error_name != ErrorIf.RankMismatch:
4618 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004619 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004620
4621 shape = []
4622 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004623 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004624 shape.append(b.shape[i])
4625 else:
4626 shape.append(a.shape[i])
4627
Jerry Ge135c9552023-05-23 20:59:32 +00004628 fuzz_idx = rng.integers(0, len(a.shape))
4629 if error_name == ErrorIf.DimensionMismatch:
4630 shape[fuzz_idx] += 1
4631
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004632 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004633 all_dtypes = [
4634 DType.INT8,
4635 DType.INT16,
4636 DType.INT32,
4637 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004638 DType.FP16,
4639 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004640 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004641 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004642 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4643 outputDType = rng.choice(wrong_dtypes)
4644 else:
4645 outputDType = a.dtype
4646
4647 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004648
4649 @staticmethod
4650 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004651 assert len(a.shape) == len(b.shape)
4652 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004653
4654 shape = []
4655 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004656 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004657 shape.append(a.shape[i])
4658
Kevin Cheng550ccc52021-03-03 11:21:43 -08004659 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004660
4661 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004662 def unaryOp(ser, rng, a, error_name=None):
4663 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004664 all_dtypes = [
4665 DType.INT8,
4666 DType.INT16,
4667 DType.INT32,
4668 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004669 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004670 DType.FP16,
4671 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004672 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004673 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4674 outputDType = rng.choice(wrong_dtypes)
4675 else:
4676 outputDType = a.dtype
4677
4678 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004679
4680 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004681 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004682 if error_name != ErrorIf.RankMismatch:
4683 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004684 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004685
4686 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004687 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004688 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004689 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4690 else:
4691 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004692
Jerry Ge135c9552023-05-23 20:59:32 +00004693 fuzz_idx = rng.integers(0, len(a.shape))
4694 if error_name == ErrorIf.DimensionMismatch:
4695 shape[fuzz_idx] += 1
4696
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004697 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004698 all_dtypes = [
4699 DType.INT8,
4700 DType.INT16,
4701 DType.INT32,
4702 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004703 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004704 DType.FP16,
4705 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004706 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004707 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4708 outputDType = rng.choice(wrong_dtypes)
4709 else:
4710 outputDType = a.dtype
4711
4712 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004713
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison operator.

    The result shape is ``a`` broadcast against ``b``: a size-1 dimension of
    ``a`` takes the size of the matching dimension of ``b`` (when ``b`` has
    that dimension). The result dtype is BOOL. For the DimensionMismatch and
    WrongOutputType ERROR_IF cases the shape or dtype is deliberately
    corrupted.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast a against b, one dimension at a time
    b_rank = len(b.shape)
    out_shape = [
        b.shape[dim] if size == 1 and b_rank > dim else size
        for dim, size in enumerate(a.shape)
    ]

    # NOTE: the fuzz index is drawn from rng on every call, not only for
    # DimensionMismatch
    bad_dim = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bad_dim] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.BOOL
    else:
        # Any of these non-BOOL types is a wrong comparison result type
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004747
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction operator.

    The output keeps the rank of ``a`` and sets the reduced ``axis`` to 1.
    For the axis-related ERROR_IF cases the input shape is kept instead. For
    ShapeOfAxisNotOne, an axis that is already 1 is replaced by a random size
    of at least 2. The dtype matches ``a`` unless ``error_name`` is
    WrongOutputType.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        # Make sure the reduced dimension really is not 1
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004776
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX (dtype INT32).

    The output shape is the input shape with ``axis`` removed. The axis is
    not removed for the invalid-axis ERROR_IF cases. The rank/shape mismatch
    ERROR_IF cases then perturb the shape, and WrongOutputType replaces INT32
    with a random other dtype.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dimension or append an extra size-1 one
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [size + rng.integers(1, 10) for size in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004810
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, the filter is OHWI and the output is NHWC with the
    channel count taken from ``filter.shape[0]``. Each spatial size i is
    (in - 1 + padding[2i] + padding[2i+1] - (k - 1) * dilations[i])
    // strides[i] + 1. The output dtype is ``accum_dtype`` unless an ERROR_IF
    case changes it.
    """
    out_h, out_w = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole multiples of the stride to avoid the non-integer case
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded, presumably because
        # either is an acceptable result type -- TODO confirm
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004862
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, the filter is ODHWI and the output is NDHWC with
    the channel count taken from ``filter.shape[0]``. For each spatial axis
    i (depth, height, width) the size is
    (in - 1 + padding[2i] + padding[2i+1] - (k - 1) * dilations[i])
    // strides[i] + 1. The output dtype is ``accum_dtype`` unless an ERROR_IF
    case changes it.
    """
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)  # 1: depth, 2: height, 3: width, 4: all
        # Grow by whole multiples of the stride to avoid the non-integer case
        for i in range(3):
            if which in (i + 1, 4):
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded, presumably because
        # either is an acceptable result type -- TODO confirm
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
4924
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, the filter is HWCM and the output is NHW(C*M).
    Each spatial size i is
    (in - 1 + padding[2i] + padding[2i+1] - (k - 1) * dilations[i])
    // strides[i] + 1, where k comes from ``filter.shape[i]``. The output
    dtype is ``accum_dtype`` unless an ERROR_IF case changes it.
    """
    out_h, out_w = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole multiples of the stride to avoid the non-integer case
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded, presumably because
        # either is an acceptable result type -- TODO confirm
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004975
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (NHWC in and out).

    Each spatial size is (in + pad_before + pad_after - kernel) // stride + 1.
    When a stride is non-positive or any padding is negative, a 1x1 spatial
    size is used instead.
    """
    params_valid = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_valid:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Incorrect stride/padding: the test is invalid anyway, so use 1x1
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole multiples of the stride to avoid the non-integer case
        if which in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    return ser.addOutput(ofm_shape, rng.choice(list(all_dtypes - {ifm.dtype})))
Eric Kunzee5e26762020-10-13 16:11:07 -07005013
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    ``input`` is (N, IC), ``filter`` is (OC, IC) and the output is (N, OC).
    The output dtype is always ``accum_dtype``. Its validity, including its
    deliberate invalidation for ERROR_IF tests, is handled during argument
    generation.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005026
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    ``a`` is (N, H, C), ``b`` is (N, C, W) and the output is (N, H, W).

    The output dtype is ``accum_dtype``, which is validated during argument
    generation. For ERROR_IF tests:
      * WrongOutputType picks a dtype from a list of wrong types for the
        input dtype. An input dtype without a dedicated list now falls back
        to every dtype except ``accum_dtype``; previously it raised
        UnboundLocalError.
      * WrongInputType uses INT32 as a potentially correct output dtype.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # Every listed dtype except INT32
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # Every listed dtype except INT48
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: any integer type is wrong
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # No dedicated list for this input dtype: use every dtype except
            # the expected accumulator type, rather than leaving
            # incorrect_types unbound
            incorrect_types = tuple(
                dtype
                for dtype in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005074
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT.

    The output shape is the first input's shape with the ``axis`` sizes of
    all inputs summed. For the ERROR_IF cases where that sum cannot be formed
    (input rank mismatch, invalid axis), the first input's shape is used
    unchanged. The output dtype follows the first input unless ``error_name``
    is WrongOutputType.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    # Ranks differ or the axis is invalid: concatenation is impossible
    no_sum_errors = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in no_sum_errors:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005110
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Output dimension i is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    PadOutputShapeMismatch shrinks one random dimension and RankMismatch
    swaps in a shape of a different rank.
    """
    out_shape = a.shape.copy()
    for dim, size in enumerate(out_shape):
        out_shape[dim] = padding[dim][0] + padding[dim][1] + size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    return ser.addOutput(out_shape, rng.choice(list(all_dtypes - {a.dtype})))
Eric Kunzee5e26762020-10-13 16:11:07 -07005141
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for DIM: an empty shape ``[]`` with dtype SHAPE.

    For WrongOutputType a random non-SHAPE dtype is used instead.
    """
    out_shape = []
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.SHAPE)

    candidates = list(
        {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
    )
    return ser.addOutput(out_shape, rng.choice(candidates))
5162
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE, using the requested ``shape``.

    For TensorSizeInputOutputMismatch every dimension is enlarged by a random
    amount. The dtype follows ``a`` unless ``error_name`` is WrongOutputType.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, extent in enumerate(out_shape):
            out_shape[idx] = extent + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005187
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE; its shape is ``size``.

    ERROR_IF handling: WrongOutputType swaps the dtype; SizeOutputShapeMismatch
    nudges every dimension; InputSizeStartLengthMismatch uses the input
    shape; RankMismatch swaps in a shape of a different rank.
    """
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, length in enumerate(out_shape):
            # Sizes of 2 or less only grow; larger sizes may grow or shrink
            deltas = [1, 2] if length <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = length + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005221
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE.

    Output dimension i is ``a.shape[i] * multiples[i]``. RankMismatch swaps
    in a shape of a different rank; WrongOutputType swaps the dtype.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)
    for dim, count in enumerate(multiples):
        out_shape[dim] = a.shape[dim] * count

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    return ser.addOutput(out_shape, rng.choice(list(all_dtypes - {a.dtype})))
Eric Kunzee5e26762020-10-13 16:11:07 -07005250
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for TRANSPOSE.

    Output dimension i is input dimension ``perms[i]``. The permutation is
    not applied for the IndexOutsideBounds / IndexUsedTwice ERROR_IF cases.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    bad_perm_errors = (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if error_name not in bad_perm_errors:
        for dst, src in enumerate(perms):
            out_shape[dst] = a.shape[src]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst in range(len(out_shape)):
            out_shape[dst] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005283
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    ``values`` is rank 3 (not checked for WrongRank) and ``indices`` is rank
    2, sharing the leading dimension. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]]. The dtype follows
    ``values`` unless ``error_name`` is WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    return ser.addOutput(out_shape, rng.choice(list(all_dtypes - {values.dtype})))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005309
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    The output has the same shape as ``values_in``. That shape is copied so
    the output tensor does not share the shape object of the ``values_in``
    tensor; the other shapers in this class also pass a fresh or copied shape.
    The dtype follows ``values_in`` unless ``error_name`` is WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy rather than alias the input tensor's shape
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005338
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The output has the same shape as the input. That shape is copied so the
    output tensor does not share the input tensor's shape object. The dtype
    is INT32 for an INT16 input and INT8 otherwise. The input is asserted to
    be INT8 or INT16 unless ``error_name`` is WrongInputType. For
    WrongOutputType a different dtype is picked at random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8
    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes.remove(output_dtype)
        output_dtype = rng.choice(wrong_dtypes)
    # Copy rather than alias the input tensor's shape
    return ser.addOutput(input.shape.copy(), output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005358
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the (batch, OH, OW, channels) output tensor for RESIZE.

    ``scale`` is [y_n, y_d, x_n, x_d]; the height is
    ((IH - 1) * y_n - offset[0] + border[0]) // y_d + 1 and the width is
    computed in the same way from the x terms, input.shape[2], offset[1] and
    border[1]. For ERROR_IF tests the sizes are clamped to a usable range and
    then perturbed as the named error requires.
    """
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
        scale[0],
        scale[1],
        scale[2],
        scale[3],
    )
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Force positive scale terms so a shape can still be computed
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    out_h = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    out_w = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # scale, offset or border may have been altered for the ERROR_IF,
        # so keep the output tensor valid: at least 1, and (except for
        # MaxDimExceeded) below gtu.MAX_RESIZE_DIMENSION
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, limit), min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _nudge(size, step):
            # Move by exactly one scale denominator so the non-integer error
            # case is not hit; step down if stepping up would reach the limit
            if size + step < gtu.MAX_RESIZE_DIMENSION:
                return size + step
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size

        which = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        if which in (1, 3):
            out_h = _nudge(out_h, scale_y_d)
        if which in (2, 3):
            out_w = _nudge(out_w, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): input.shape[3] is not read here, presumably because a
        # wrong-rank input may lack a fourth dimension; batch is reused instead
        channels = input.shape[0]
    else:
        channels = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)

    return serializer.addOutput([batch, out_h, out_w, channels], output_dtype)
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005436 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005437
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create a type-conversion output tensor.

    The new tensor has the same shape as ``val`` and dtype ``out_dtype``.
    ``rng`` and ``error_name`` are not read here.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005441
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operator.

    Args:
        ser: serializer used to register the output tensor.
        rng: random generator used for the ERROR_IF perturbations.
        ifm: input tensor; its ``dtype`` is read for WrongOutputType.
        output_shape: requested output shape. NOTE: for
            ConvOutputShapeMismatch this list is modified IN PLACE, so the
            caller's list also sees the perturbed height/width.
        accum_dtype: accumulator dtype, used as the output dtype unless an
            ERROR_IF case overrides it.
        error_name: optional ErrorIf case to provoke.

    Returns:
        The tensor returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)  # 1: height, 2: width, 3: both
        if which in (1, 3):
            output_shape[1] += rng.choice(deltas)
        if which in (2, 3):
            output_shape[2] += rng.choice(deltas)

    # With an invalid input type, report a potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because both
        # are acceptable accumulator types for that input
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005467
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the pair of output tensors for an FFT2D operator.

    Both outputs share one shape and dtype, copied from ``ifm1`` and
    perturbed for the WrongOutputType, BatchMismatch and
    FFTOutputShapeMismatch ERROR_IF cases. (Presumably the two outputs
    are the real and imaginary parts - not visible here.)

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])  # height or width
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5498
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the pair of output tensors for an RFFT2D operator.

    The outputs keep every dimension of ``value`` except the last, which
    becomes ``W // 2 + 1``. The dtype follows ``value``. The shape or dtype
    is perturbed for the WrongOutputType, BatchMismatch and
    FFTOutputShapeMismatch ERROR_IF cases.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])  # height or width
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]