blob: 2290c54f3f53af5ba76995b8729478f9d13148a5 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every auto-generated C++ meta data file.
# The copyright year is taken from datetime.today() once, at import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further 8K level limits; presumably the maximum scale, kernel and stride
# values allowed by that level (NOTE(review): confirm against the spec).
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
# (number of test sets and a minimum value; exact use is defined by the
# argument generators - TODO confirm)
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise generator state from the parsed command line ``args``.

    Seeds ``self.rng`` from ``args.random_seed``, builds the op lists, and
    precomputes ``self.random_float_range`` (one sorted (low, high) tuple per
    float dtype) from ``args.tensor_fp_value_range``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        # "max"/"-max" become +/-maxFP; any other value is clamped into
        # [-maxFP, maxFP]. The result is returned sorted as a tuple.
        def clampToType(value):
            if value == "max":
                return maxFP
            if value == "-max":
                return -maxFP
            if value < 0:
                # Trim to minimum data type value
                return max(value, -maxFP)
            if value > 0:
                # Trim to maximum data type value
                return min(value, maxFP)
            return value

        return tuple(sorted(clampToType(v) for v in rangeFP))

    # Work out floating point range for each float dtype
    self.random_float_range = {
        dtype: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
        )
        for dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
84
def createSerializer(self, opName, testPath):
    """Create ``basePath/opName/testPath`` and a fresh serializer writing there.

    Sets ``self.testPath`` (relative to ``basePath``) and ``self.ser``. The
    constant mode is INPUTS for lazy data generation, EMBED_DUMP when
    constants are dumped, and EMBED (data in the flatbuffer) otherwise.
    """
    self.testPath = os.path.join(opName, testPath)

    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(outputDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the current serializer (``self.ser``); None until one is created."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test into ``basePath/testPath``.

    Files written, in order:
      * ``<testName>.tosa`` - the serialized flatbuffer;
      * ``<testName>_meta_data_gen.cpp`` - only when ``metaData`` has a
        "data_gen" entry and lazy data generation is enabled;
      * ``<testName>_meta_compliance.cpp`` - only when ``metaData`` has a
        "compliance" entry;
      * ``desc.json`` - the test descriptor, with ``metaData`` under "meta".

    The descriptor is schema-validated after the .tosa file is written but
    before any of the other files.
    """
    outDir = Path(self.basePath) / self.testPath
    tosaFileName = f"{testName}.tosa"

    # TOSA flatbuffer binary
    with (outDir / tosaFileName).open("wb") as fd:
        fd.write(self.ser.serialize())

    # JSON descriptor from the serializer, plus any extra meta data
    desc = json.loads(self.ser.writeJson(tosaFileName))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    def writeCppMetaData(fileName, commentLine, declaration, payload):
        # Emit ``payload`` as JSON inside a C++ raw string literal
        with (outDir / fileName).open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(commentLine)
            fd.write(declaration)
            json.dump(payload, fd)
            fd.write(')";\n\n')

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{outDir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Compliance meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{outDir.stem} = R"(',
                metaData["compliance"],
            )

    with (outDir / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a new generator.

    Uses ``seed`` when given, otherwise ``self.random_seed + 1``.
    """
    effective_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(effective_seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value bounds used when generating ``dtype`` data.

    For integer and bool dtypes ``high`` is exclusive, unless
    ``high_inclusive`` is True, in which case ``(low, high - 1)`` is returned.
    Float dtypes return the precomputed ``self.random_float_range`` entry
    unchanged; ``high_inclusive`` is ignored for them.

    Raises:
        Exception: if ``dtype`` is not a handled type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    int32_bounds = (-(1 << 31), (1 << 31))
    for candidate, bounds in (
        (DType.BOOL, (0, 2)),
        (DType.UINT8, (0, 256)),
        (DType.UINT16, (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, int32_bounds),
        # restricting too large value for SHAPE
        (DType.SHAPE, int32_bounds),
        (DType.INT48, (-(1 << 47), (1 << 47))),
    ):
        if dtype == candidate:
            break
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    low, high = bounds
    if high_inclusive:
        # Inclusive range: low <= value <= high
        return (low, high - 1)
    # Exclusive high: low <= value < high
    return (low, high)
184
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding ``dtype`` values.

    Values are drawn from ``data_range`` or, when that is None, from
    ``getDTypeRange(dtype)``. BOOL ignores the range. INT48 is held in int64,
    other integer types in int32, FP16 in float16, and FP32/BF16 in float32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)

    samples_f32 = np.float32(samples)
    if dtype == DType.BF16:
        # Reduce to bf16 precision via gtu.vect_f32_to_bf16 (original note:
        # floors the last 16 bits of each f32 value); stored as float32
        return np.float32(gtu.vect_f32_to_bf16(samples_f32))
    return samples_f32
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair and return them in order.

    With lazy data generation the placeholders get no data (None); otherwise
    each is filled from getRandTensor().
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
223
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair and return them in order.

    With lazy data generation the constants get no data (None); otherwise
    each is filled from getRandTensor().
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
236
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    If a target shape has been forced with setTargetShape(), that shape is
    returned and ``rank`` is ignored. Otherwise ``rank`` dimensions are drawn
    with self.rng from ``self.args.tensor_shape_range`` (low inclusive,
    high exclusive).
    """
    # Compare with None explicitly: truthiness raises ValueError for a
    # multi-element numpy array target shape.
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700247
def setTargetShape(self, shape):
    """Force makeShape() to return ``shape``; pass None to go back to random shapes."""
    self.targetted_shape = shape
    return None
250
def randInt(self, low=0, high=256):
    """Return one np.int32 drawn from [low, high) with self.rng."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
253
def getRandNumberDType(self, dtype):
    """Return a single random scalar of ``dtype`` from getDTypeRange(dtype).

    FP32 -> np.float32, FP16 -> np.float16, BF16 -> result of
    gtu.vect_f32_to_bf16, BOOL -> random choice of False/True,
    INT48 -> np.int64, every other type -> np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])
    if dtype == DType.INT48:
        # Special size
        return np.int64(self.rng.integers(low, high, size=1))[0]
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample_f32 = np.float32(sample)
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample_f32)
        return sample_f32
    return np.int32(self.rng.integers(low, high, size=1))[0]
271
def shapeStr(self, shape):
    """Return ``shape`` as an "x"-separated string, e.g. [1, 2, 3] -> "1x2x3".

    An empty shape gives "". ``self`` is unused.
    """
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700280
def typeStr(self, dtype):
    """Return the short name of ``dtype`` from gtu.DTYPE_ATTRIBUTES.

    A list/tuple (at least 2 entries) gives the first two names joined with
    "x". A third entry is the accumulator and is not named, although every
    entry is still converted.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    # Convert all entries (so unknown ones still raise), then keep two
    names = [self.typeStr(t) for t in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
def typeWidth(self, dtype):
    """Get the datatype width for data types (gtu.DTYPE_ATTRIBUTES "width").

    Raises:
        Exception: if ``dtype`` has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    attributes = gtu.DTYPE_ATTRIBUTES
    if dtype not in attributes:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return attributes[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700301
def constrictBatchSize(self, shape):
    """Clamp the first (batch) dimension of ``shape`` in place and return it.

    Clamping to ``self.args.max_batch_size`` happens only when that value is
    set and no explicit ``self.args.target_shapes`` were requested. A rank-0
    (empty) shape has no batch dimension and is returned unchanged; indexing
    it would raise IndexError.
    """
    if (
        self.args.max_batch_size
        and not self.args.target_shapes
        and len(shape) > 0
    ):
        shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
307
def makeDimension(self):
    """Return one random dimension size drawn with randInt() from tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
312
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for an op's expected output tensor.

    Returns None (no compliance checking) in three cases:
      * ERROR_IF tests;
      * output dtypes rejected by gtu.dtypeIsSupportedByCompliance;
      * MATMUL/CONV2D/FULLY_CONNECTED when the input dtype is rejected.

    Otherwise returns a dict with "mode" (a gtu.ComplianceMode name) and
    "data_type". It may also contain "dot_product_info" or "ulp_info".
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
    if errorName:
        # No compliance for error tests
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        # Use "ksb" when present, otherwise "ks"
        ks_key = "ksb" if "ksb" in argsDict else "ks"
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ks_key]),
        }
        mode = gtu.ComplianceMode.DOT_PRODUCT
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        mode = gtu.ComplianceMode.ULP
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
357
358 # Build Op functions
359 # Create the output tensor (calling OutputShaper as needed)
360 # Do final tweaks to attributes (if necessary for errorIf)
361 # Add Op into graph
362 # Return resulting tensor information or BuildInfo
363
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict.

    Returned by build_* functions that also produce compliance meta data.
    """

    def __init__(self, resultTensor, complianceDict):
        # Output tensor created by the build function
        self.resultTensor = resultTensor
        # Compliance meta data for that tensor; may be None
        self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700370
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    Returns a TosaTestGen.BuildInfo (result tensor plus compliance meta data),
    or None when ERROR_IF validation rejects the test. For NEGATE, ``qinfo``
    supplies the two zero points of the NegateAttribute.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        # Ensure new output type has correct qinfo
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700423
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input broadcasting operator to the graph.

    ``validator_fcns`` defaults to None like the other build_* functions;
    positional callers are unaffected. ``qinfo`` is accepted for a uniform
    signature but is not used here.

    Returns a TosaTestGen.BuildInfo (result tensor plus compliance meta data),
    or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700465
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input, non-broadcasting operator and return its result tensor.

    Unlike the other builders, this one performs no ERROR_IF validation:
    ``validator_fcns`` and ``error_name`` are accepted but unused.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])
    return result_tens
470
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add a two-input broadcasting op carrying ArithmeticRightShiftAttribute(round).

    Returns the result tensor, or None when ERROR_IF validation rejects it.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
508
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL-style two-input op with MulAttribute(args_dict["shift"]).

    For non-float inputs the result is forced to INT32. For the
    WrongOutputType error test, the output dtype is replaced by a random
    INT8/INT16/INT48. Returns a TosaTestGen.BuildInfo, or None when ERROR_IF
    validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(self.rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700564
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a single-input op with TableAttribute(table) and return its result tensor.

    Returns None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
598
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a three-input select op (cond, a, b) and return its result tensor.

    Returns None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
635
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input comparison op to the graph.

    Returns a TosaTestGen.BuildInfo (result tensor plus compliance meta data),
    or None when ERROR_IF validation rejects the test. ``qinfo`` is accepted
    for a uniform signature but is not used here.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700683
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000684 def build_argmax(
685 self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
686 ):
687 assert len(inputs) == 1
688 a = inputs[0]
689 axis = args_dict["axis"]
690 result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100691
692 # Invalidate Input/Output list for error if checks.
693 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000694 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100695 pCount, cCount = op["operands"]
696 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000697 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
698 self, error_name, input_list, output_list
699 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100700
Les Bell729b0352021-11-24 10:28:21 +0000701 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100702 self.ser,
703 validator_fcns,
704 error_name,
705 op=op,
706 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000707 input_shape=a.shape,
708 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000709 output_shape=result_tensor.shape,
710 output_dtype=result_tensor.dtype,
711 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100712 input_list=input_list,
713 output_list=output_list,
714 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000715 ):
716 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700717
718 attr = ts.TosaSerializerAttribute()
719 attr.AxisAttribute(axis)
720
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000721 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +0000722
723 compliance = self.tensorComplianceMetaData(
724 op, inputs[0].dtype, args_dict, result_tensor, error_name
725 )
726 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700727
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000728 def build_pool2d(
729 self,
730 op,
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100731 inputs,
732 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000733 validator_fcns=None,
734 error_name=None,
735 qinfo=None,
736 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100737 assert len(inputs) == 1
738 input = inputs[0]
739 # max_pool has no accum_dtype
740 accum_dtype = (
741 args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
742 )
743 stride = args_dict["stride"]
744 pad = args_dict["pad"]
745 kernel = args_dict["kernel"]
746
Jeremy Johnson0601f802023-11-08 16:28:09 +0000747 result_tensor = OutputShaper.pool2dOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000748 self.ser, self.rng, input, kernel, stride, pad, error_name
749 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100750
751 # Ensure new output type has correct qinfo
752 if error_name == ErrorIf.WrongInputType:
753 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000754 qinfo = [
755 TosaQuantGen.getZeroPoint(self, input.dtype),
Jeremy Johnson0601f802023-11-08 16:28:09 +0000756 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000757 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100758
759 # Invalidate Input/Output list for error if checks.
760 input_list = [input.name]
Jeremy Johnson0601f802023-11-08 16:28:09 +0000761 output_list = [result_tensor.name]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100762 pCount, cCount = op["operands"]
763 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000764 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
765 self, error_name, input_list, output_list
766 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100767
Les Bell729b0352021-11-24 10:28:21 +0000768 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100769 self.ser,
770 validator_fcns,
771 error_name,
772 op=op,
773 input_shape=input.shape,
774 input_dtype=input.dtype,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000775 output_shape=result_tensor.shape,
776 output_dtype=result_tensor.dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100777 kernel=kernel,
778 stride=stride,
779 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000780 qinfo=qinfo,
Jeremy Johnson0601f802023-11-08 16:28:09 +0000781 result_tensors=[result_tensor],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100782 input_list=input_list,
783 output_list=output_list,
784 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000785 ):
786 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700787
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000788 if qinfo is None:
789 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700790
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000791 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100792 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000793
794 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700795
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100796 compliance = self.tensorComplianceMetaData(
797 op, inputs[0].dtype, args_dict, result_tensor, error_name
798 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100799
800 return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100801
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000802 def build_conv2d(
803 self,
804 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100805 inputs,
806 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000807 validator_fcns=None,
808 error_name=None,
809 qinfo=None,
810 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100811 assert len(inputs) == 3
812 ifm, filter, bias = inputs
813 accum_dtype = args_dict["acc_type"]
814 strides = args_dict["stride"]
815 padding = args_dict["pad"]
816 dilations = args_dict["dilation"]
817
Kevin Cheng550ccc52021-03-03 11:21:43 -0800818 assert len(padding) == 4
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100819 result_tensor = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100820 self.ser,
821 self.rng,
822 ifm,
823 filter,
824 accum_dtype,
825 strides,
826 padding,
827 dilations,
828 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000829 )
830
831 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000832 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
833 DType.INT8,
834 DType.UINT8,
835 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000836 qinfo = [
837 TosaQuantGen.getZeroPoint(self, ifm.dtype),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100838 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000839 ]
Les Bell0e027d42021-11-09 14:42:14 +0000840
841 # Invalidate Input/Output list for error_if checks.
842 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100843 output_list = [result_tensor.name]
Les Bell0e027d42021-11-09 14:42:14 +0000844 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000845 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
846 self, error_name, input_list, output_list
847 )
Les Bell0e027d42021-11-09 14:42:14 +0000848
Les Bell729b0352021-11-24 10:28:21 +0000849 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000850 self.ser,
851 validator_fcns,
852 error_name,
853 op=op,
854 input_dtype=ifm.dtype,
855 weight_dtype=filter.dtype,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100856 output_dtype=result_tensor.dtype,
Les Bell0e027d42021-11-09 14:42:14 +0000857 qinfo=qinfo,
858 input_list=input_list,
859 num_operands=num_operands,
860 output_list=output_list,
861 pad=padding,
862 stride=strides,
863 dilation=dilations,
864 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100865 weight_shape=filter.shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100866 output_shape=result_tensor.shape,
Les Bell729b0352021-11-24 10:28:21 +0000867 ):
868 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700869
Tai Lyd3797f02023-11-15 23:06:19 +0000870 # TODO - Test local_bound, for now set local bound attribute to False
871 local_bound = False
872
Eric Kunzee5e26762020-10-13 16:11:07 -0700873 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +0000874 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -0700875
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000876 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100877
878 compliance = self.tensorComplianceMetaData(
879 op, ifm.dtype, args_dict, result_tensor, error_name
880 )
881
882 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700883
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000884 def build_conv3d(
885 self,
886 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100887 inputs,
888 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000889 validator_fcns=None,
890 error_name=None,
891 qinfo=None,
892 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100893 assert len(inputs) == 3
894 ifm, filter, bias = inputs
895 accum_dtype = args_dict["acc_type"]
896 strides = args_dict["stride"]
897 padding = args_dict["pad"]
898 dilations = args_dict["dilation"]
899
Kevin Cheng1533b852021-09-01 12:51:58 -0700900 assert len(padding) == 6
901 result_tens = OutputShaper.conv3dOp(
James Ward8b390432022-08-12 20:48:56 +0100902 self.ser,
903 self.rng,
904 ifm,
905 filter,
906 accum_dtype,
907 strides,
908 padding,
909 dilations,
910 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000911 )
912
913 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000914 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
915 DType.INT8,
916 DType.UINT8,
917 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000918 qinfo = [
919 TosaQuantGen.getZeroPoint(self, ifm.dtype),
920 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
921 ]
Les Bell0e027d42021-11-09 14:42:14 +0000922
923 # Invalidate Input/Output list for error_if checks.
924 input_list = [ifm.name, filter.name, bias.name]
925 output_list = [result_tens.name]
926 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000927 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
928 self, error_name, input_list, output_list
929 )
Les Bell0e027d42021-11-09 14:42:14 +0000930
Les Bell729b0352021-11-24 10:28:21 +0000931 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000932 self.ser,
933 validator_fcns,
934 error_name,
935 op=op,
936 input_dtype=ifm.dtype,
937 weight_dtype=filter.dtype,
938 output_dtype=result_tens.dtype,
939 qinfo=qinfo,
940 input_list=input_list,
941 num_operands=num_operands,
942 output_list=output_list,
943 pad=padding,
944 stride=strides,
945 dilation=dilations,
946 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100947 weight_shape=filter.shape,
948 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000949 ):
950 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700951
Tai Lyd3797f02023-11-15 23:06:19 +0000952 # TODO - Test local_bound, for now set local bound attribute to False
953 local_bound = False
954
Kevin Cheng1533b852021-09-01 12:51:58 -0700955 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +0000956 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Kevin Cheng1533b852021-09-01 12:51:58 -0700957
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000958 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700959 return result_tens
960
Kevin Cheng550ccc52021-03-03 11:21:43 -0800961 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000962 self,
963 op,
964 ifm,
965 filter,
966 bias,
James Ward8b390432022-08-12 20:48:56 +0100967 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000968 stride,
TatWai Chong24594f52022-06-08 00:48:04 -0700969 out_pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000970 output_shape,
971 validator_fcns=None,
972 error_name=None,
973 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800974 ):
TatWai Chong24594f52022-06-08 00:48:04 -0700975 assert len(out_pad) == 4
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000976 result_tens = OutputShaper.transposeConv2DOp(
James Ward8b390432022-08-12 20:48:56 +0100977 self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000978 )
Les Bell0e027d42021-11-09 14:42:14 +0000979
980 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000981 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
982 DType.INT8,
983 DType.UINT8,
984 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000985 qinfo = [
986 TosaQuantGen.getZeroPoint(self, ifm.dtype),
987 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
988 ]
Les Bell0e027d42021-11-09 14:42:14 +0000989
990 # Invalidate Input/Output list for error_if checks.
991 input_list = [ifm.name, filter.name, bias.name]
992 output_list = [result_tens.name]
993 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000994 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
995 self, error_name, input_list, output_list
996 )
Les Bell0e027d42021-11-09 14:42:14 +0000997
Les Bell729b0352021-11-24 10:28:21 +0000998 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000999 self.ser,
1000 validator_fcns,
1001 error_name,
1002 op=op,
1003 input_dtype=ifm.dtype,
1004 weight_dtype=filter.dtype,
1005 output_dtype=result_tens.dtype,
1006 qinfo=qinfo,
1007 input_list=input_list,
1008 num_operands=num_operands,
1009 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -07001010 pad=out_pad,
Les Bell0e027d42021-11-09 14:42:14 +00001011 stride=stride,
Les Bell0e027d42021-11-09 14:42:14 +00001012 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001013 weight_shape=filter.shape,
1014 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +00001015 ):
1016 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001017
Tai Lyd3797f02023-11-15 23:06:19 +00001018 # TODO - Test local_bound, for now set local bound attribute to False
1019 local_bound = False
1020
Eric Kunzee5e26762020-10-13 16:11:07 -07001021 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001022 attr.TransposeConvAttribute(
1023 out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
1024 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001025
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001026 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001027 return result_tens
1028
Kevin Cheng550ccc52021-03-03 11:21:43 -08001029 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001030 self,
1031 op,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001032 inputs,
1033 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001034 validator_fcns=None,
1035 error_name=None,
1036 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001037 ):
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001038 assert len(inputs) == 3
1039 ifm, filter, bias = inputs
1040 accum_dtype = args_dict["acc_type"]
1041 strides = args_dict["stride"]
1042 padding = args_dict["pad"]
1043 dilations = args_dict["dilation"]
1044
Kevin Cheng550ccc52021-03-03 11:21:43 -08001045 result_tens = OutputShaper.depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01001046 self.ser,
1047 self.rng,
1048 ifm,
1049 filter,
1050 accum_dtype,
1051 strides,
1052 padding,
1053 dilations,
1054 error_name,
Les Bell0e027d42021-11-09 14:42:14 +00001055 )
1056
1057 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001058 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
1059 DType.INT8,
1060 DType.UINT8,
1061 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001062 qinfo = [
1063 TosaQuantGen.getZeroPoint(self, ifm.dtype),
1064 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
1065 ]
Les Bell0e027d42021-11-09 14:42:14 +00001066
1067 # Invalidate Input/Output list for error_if checks.
1068 input_list = [ifm.name, filter.name, bias.name]
1069 output_list = [result_tens.name]
1070 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001071 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1072 self, error_name, input_list, output_list
1073 )
Les Bell0e027d42021-11-09 14:42:14 +00001074
Les Bell729b0352021-11-24 10:28:21 +00001075 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +00001076 self.ser,
1077 validator_fcns,
1078 error_name,
1079 op=op,
1080 input_dtype=ifm.dtype,
1081 weight_dtype=filter.dtype,
1082 output_dtype=result_tens.dtype,
1083 qinfo=qinfo,
1084 input_list=input_list,
1085 num_operands=num_operands,
1086 output_list=output_list,
1087 pad=padding,
1088 stride=strides,
1089 dilation=dilations,
1090 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001091 weight_shape=filter.shape,
1092 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +00001093 ):
1094 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001095
Tai Lyd3797f02023-11-15 23:06:19 +00001096 # TODO - Test local_bound, for now set local bound attribute to False
1097 local_bound = False
1098
Eric Kunzee5e26762020-10-13 16:11:07 -07001099 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00001100 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
Eric Kunzee5e26762020-10-13 16:11:07 -07001101
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001102 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001103 return result_tens
1104
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001105 def build_fully_connected(
James Ward8b390432022-08-12 20:48:56 +01001106 self,
1107 op,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001108 inputs,
1109 args_dict,
James Ward8b390432022-08-12 20:48:56 +01001110 validator_fcns=None,
1111 error_name=None,
1112 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001113 ):
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001114 assert len(inputs) == 3
1115 ifm, filter, bias = inputs
1116 accum_dtype = args_dict["acc_type"]
1117
1118 result_tensor = OutputShaper.fullyConnectedOp(
James Ward8b390432022-08-12 20:48:56 +01001119 self.ser, self.rng, ifm, filter, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001120 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001121
1122 # Invalidate Input/Output list for error if checks.
1123 input_list = [ifm.name, filter.name, bias.name]
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001124 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001125 pCount, cCount = op["operands"]
1126 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001127 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1128 self, error_name, input_list, output_list
1129 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001130
Les Bell729b0352021-11-24 10:28:21 +00001131 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001132 self.ser,
1133 validator_fcns,
1134 error_name,
1135 op=op,
1136 input_shape=ifm.shape,
1137 input_dtype=ifm.dtype,
1138 weight_dtype=filter.dtype,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001139 output_shape=result_tensor.shape,
1140 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001141 qinfo=qinfo,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001142 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001143 input_list=input_list,
1144 output_list=output_list,
1145 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001146 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001147 ):
1148 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001149
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001150 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001151 attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001152
1153 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001154
1155 compliance = self.tensorComplianceMetaData(
1156 op, ifm.dtype, args_dict, result_tensor, error_name
1157 )
1158
1159 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001160
James Ward8b390432022-08-12 20:48:56 +01001161 def build_matmul(
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001162 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
James Ward8b390432022-08-12 20:48:56 +01001163 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001164 assert len(inputs) == 2
1165 a, b = inputs
Jeremy Johnson1271c442023-09-05 11:39:26 +01001166 accum_dtype = args_dict["acc_type"]
1167 result_tensor = OutputShaper.matmulOp(
James Ward8b390432022-08-12 20:48:56 +01001168 self.ser, self.rng, a, b, accum_dtype, error_name
1169 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001170
1171 # Invalidate Input/Output list for error if checks.
1172 input_list = [a.name, b.name]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001173 output_list = [result_tensor.name]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001174 pCount, cCount = op["operands"]
1175 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001176 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1177 self, error_name, input_list, output_list
1178 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001179
Les Bell729b0352021-11-24 10:28:21 +00001180 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001181 self.ser,
1182 validator_fcns,
1183 error_name,
1184 op=op,
1185 input_shape=a.shape,
1186 input_dtype=a.dtype,
1187 input2_shape=b.shape,
1188 input2_dtype=b.dtype,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001189 output_shape=result_tensor.shape,
1190 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001191 qinfo=qinfo,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001192 result_tensors=[result_tensor],
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001193 input_list=input_list,
1194 output_list=output_list,
1195 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001196 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001197 ):
1198 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001199
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001200 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001201 attr.MatMulAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001202
1203 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001204
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001205 compliance = self.tensorComplianceMetaData(
1206 op, a.dtype, args_dict, result_tensor, error_name
1207 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001208
1209 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001210
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001211 def build_reduce(
1212 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
1213 ):
1214 assert len(inputs) == 1
1215 a = inputs[0]
1216 axis = args_dict["axis"]
1217 result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
Matthew Haddond6ce7252021-09-29 15:35:44 +01001218
1219 # Invalidate Input/Output list for error if checks.
1220 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001221 output_list = [result_tensor.name]
Matthew Haddond6ce7252021-09-29 15:35:44 +01001222 pCount, cCount = op["operands"]
1223 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001224 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1225 self, error_name, input_list, output_list
1226 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01001227
Les Bell729b0352021-11-24 10:28:21 +00001228 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01001229 self.ser,
1230 validator_fcns,
1231 error_name,
1232 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001233 axis=axis,
1234 input_shape=a.shape,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001235 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001236 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001237 output_dtype=result_tensor.dtype,
1238 result_tensors=[result_tensor],
Matthew Haddond6ce7252021-09-29 15:35:44 +01001239 input_list=input_list,
1240 output_list=output_list,
1241 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001242 ):
1243 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001244
1245 attr = ts.TosaSerializerAttribute()
1246 attr.AxisAttribute(axis)
1247
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001248 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001249
1250 if op["op"] == Op.REDUCE_PRODUCT:
1251 # TODO: Add compliance support!
1252 compliance = None
1253 else:
1254 compliance = self.tensorComplianceMetaData(
1255 op, a.dtype, args_dict, result_tensor, error_name
1256 )
1257
1258 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001259
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001260 def build_clamp(
1261 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1262 ):
1263 assert len(inputs) == 1
1264 a = inputs[0]
1265
1266 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001267
Jeremy Johnson18e26662021-07-22 16:15:29 +01001268 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001269
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001270 if error_name == ErrorIf.MaxSmallerMin:
1271 # Make sure the numbers are different to invoke this error
1272 while v[0] == v[1]:
1273 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
1274 max_val = min(v)
1275 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001277 max_val = max(v)
1278 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001279
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001280 # Invalidate Input/Output list for error if checks.
1281 input_list = [a.name]
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001282 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001283 pCount, cCount = op["operands"]
1284 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001285 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1286 self, error_name, input_list, output_list
1287 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001288
Les Bell729b0352021-11-24 10:28:21 +00001289 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001290 self.ser,
1291 validator_fcns,
1292 error_name,
1293 op=op,
1294 max_val=max_val,
1295 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001296 input_shape=a.shape,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001297 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001298 input_dtype=a.dtype,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001299 output_dtype=result_tensor.dtype,
1300 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001301 input_list=input_list,
1302 output_list=output_list,
1303 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001304 ):
1305 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001306
1307 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001308 if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
1309 if a.dtype == DType.FP16:
1310 # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
1311 min_val = min_val.astype(np.float32)
1312 max_val = max_val.astype(np.float32)
1313
1314 attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001315 else:
James Ward34071252022-12-07 15:48:47 +00001316 attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001317
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001318 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00001319
1320 compliance = self.tensorComplianceMetaData(
1321 op, a.dtype, args_dict, result_tensor, error_name
1322 )
1323
1324 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001325
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001326 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
1327 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001328 attr = ts.TosaSerializerAttribute()
1329
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001330 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
Eric Kunzee5e26762020-10-13 16:11:07 -07001331
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001332 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001333 return result_tens
1334
1335 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001336 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1337 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001338
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001339 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001340 return result_tens
1341
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a single-input operator with no attribute.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001382
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a concatenation of ``inputs`` along ``args_dict["axis"]``.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Input tensors; the first one supplies the shape/dtype given
            to the ERROR_IF validators.
        args_dict: Test arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance meta data, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Exact type check (not isinstance) so that bool axes are rejected too.
        assert type(axis) is int

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001430
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize ``op["op"]`` on one tensor with a PadAttribute.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"pad"``, ``"pad_const_int"`` and
            ``"pad_const_fp"`` are read. ``"pad"`` must support ``flatten()``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute is written into the serializer's builder before the
    # ERROR_IF validation runs.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001488
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize ``op["op"]`` on one input tensor with an AxisAttribute.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"axis"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance meta data, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001534
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` on one tensor with a ReshapeAttribute.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"new_shape"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    target_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, target_shape, error_name
    )

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(target_shape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001580
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` on one tensor with an AxisAttribute.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"axis"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance meta data, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` on ``a`` with a TransposeAttribute of ``perms``.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        a: Input tensor.
        perms: Permutation passed to the attribute and the validators.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, a, perms, error_name
    )

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1656
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` on ``a`` with a SliceAttribute(start, size).

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        a: Input tensor.
        start: Slice start, passed to the attribute and the validators.
        size: Slice size, passed to the attribute and the validators.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not checks_pass:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1695
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` on ``a`` with a TileAttribute of ``multiples``.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        a: Input tensor.
        multiples: Passed to ``OutputShaper.tileOp`` and the attribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not checks_pass:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tensor
1730
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` with a values tensor and an indices tensor.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Exactly two tensors, ``(values, indices)``.
        args_dict: Test arguments, forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    # Only the values tensor's shape/dtype are given to the validators.
    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001773
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` with values_in, indices and input tensors.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Exactly three tensors, ``(values_in, indices, input)``.
        args_dict: Test arguments, forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 3
    # Named input_tensor (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    # Only the values_in tensor's shape/dtype are given to the validators.
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001815
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize ``op["op"]`` on ``input`` with a ResizeAttribute.

    ``input_dtype`` and ``output_dtype`` are supplied by the caller and are
    passed to both ``OutputShaper.resizeOp`` and the ERROR_IF validators.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        input: Input tensor (parameter name kept for keyword callers).
        mode, scale, offset, border: Resize parameters, passed to the
            output shaper, the validators and ``ResizeAttribute``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    src = input  # local alias so the builtin name is not used below
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1877
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an operator with two inputs and two outputs.

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from the
    corresponding input. ``validator_fcns`` is not used.

    Args:
        op: Operator entry dict (as subscripted with ``"op"`` by the sibling
            builders) or, for backward compatibility, a bare op code.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Unused.
        error_name: ERROR_IF variant being generated, or None.

    Returns:
        The output tensor for ``val`` only; the second output is added to
        the operator but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The other builders pass op["op"] to addOperator, not the whole entry
    # dict; accept either form here.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1885
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Pass ``val`` to ``self.ser.addOutputTensor`` and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are not used, and no
    operator is added to the serializer.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001889
1890 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a type-conversion operator on one tensor.

    Args:
        op: Operator entry; ``op["op"]`` is passed to the serializer and
            ``op["operands"]`` unpacks into two counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; ``"out_type"`` selects the output dtype.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF variant being generated, or None.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and compliance
        meta data, or None when ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    val = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    # Invalidate the input/output name lists for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001934
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001935 def build_rescale(
1936 self,
1937 op,
1938 val,
1939 out_dtype,
1940 scale32,
1941 double_round,
1942 per_channel,
1943 validator_fcns,
1944 error_name,
1945 ):
1946 result_tens = OutputShaper.typeConversionOp(
1947 self.ser, self.rng, val, out_dtype, error_name
1948 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001949
1950 if per_channel:
1951 nc = val.shape[-1]
1952 else:
1953 nc = 1
1954
1955 in_type_width = self.typeWidth(val.dtype)
1956 out_type_width = self.typeWidth(out_dtype)
1957
Kevin Cheng3a478572021-01-22 17:21:02 -08001958 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001959 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001960 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001961 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001962 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001963 in_type_width += 1
1964 elif error_name in [
1965 ErrorIf.InputZeroPointNotZero,
1966 ErrorIf.U16InputZeroPointNotValid,
1967 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001968 input_zp = self.randInt(-128, 128)
1969 if input_zp == 0:
1970 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001971 in_type_width += 1
1972 elif val.dtype == DType.UINT16:
1973 # Must come after ErrorIf.U16InputZeroPointNotValid check
1974 input_zp = self.rng.choice([0, 32768])
1975 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001976 else:
1977 input_zp = 0
1978
Kevin Cheng3a478572021-01-22 17:21:02 -08001979 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001980 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001981 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001982 elif out_dtype == DType.UINT8:
1983 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001984 out_type_width += 1
1985 elif error_name in [
1986 ErrorIf.OutputZeroPointNotZero,
1987 ErrorIf.U16OutputZeroPointNotValid,
1988 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001989 output_zp = self.randInt(-128, 128)
1990 if output_zp == 0:
1991 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001992 out_type_width += 1
1993 elif out_dtype == DType.UINT16:
1994 # Must come after ErrorIf.U16OutputZeroPointNotValid check
1995 output_zp = self.rng.choice([0, 32768])
1996 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001997 else:
1998 output_zp = 0
1999
2000 # Calculate scale based on:
2001 # scale = a *(2^output_width)/(2^input_width))
2002
2003 a = np.float32(self.rng.random(size=[nc]))
2004 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2005
2006 if scale32:
2007 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002008 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002009 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2010 else:
2011 # Cap the scaling at 2^15 - 1 for scale16
2012 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2013
Kevin Cheng550ccc52021-03-03 11:21:43 -08002014 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002015
2016 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2017 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002018 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2019 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002020
2021 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002022 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2023 scale_arr[i], scale32
2024 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002025 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2026 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002027
Kevin Cheng550ccc52021-03-03 11:21:43 -08002028 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002029 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002030 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002031 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002032 assert val.placeholderFilename
2033 values = np.load(
2034 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2035 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002036 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2037 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2038 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2039 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002040 if not np.all(np.array_equal(values, val_adj)):
2041 # Values changed so overwrite file with new values
2042 np.save(
2043 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2044 val_adj,
2045 False,
2046 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002047
Matthew Haddonc2025212021-10-08 21:21:05 +01002048 # Invalidate Input/Output list for error if checks.
2049 input_list = [val.name]
2050 output_list = [result_tens.name]
2051 pCount, cCount = op["operands"]
2052 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002053 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2054 self, error_name, input_list, output_list
2055 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002056
2057 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002058 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002059 self.ser,
2060 validator_fcns,
2061 error_name,
2062 op=op,
2063 input_dtype=val.dtype,
2064 output_dtype=out_dtype,
2065 input_shape=val.shape,
2066 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002067 scale32=scale32,
2068 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002069 input_list=input_list,
2070 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002071 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002072 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002073 ):
2074 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002075
Eric Kunzee5e26762020-10-13 16:11:07 -07002076 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002077 attr.RescaleAttribute(
2078 input_zp,
2079 output_zp,
2080 multiplier_arr,
2081 shift_arr,
2082 scale32,
2083 double_round,
2084 per_channel,
2085 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002086
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002087 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002088 return result_tens
2089
def _get_condition_tensor(self, op, cond, error_name):
    """Add a constant condition tensor holding ``cond`` and return it.

    By default the tensor is a rank-0 (single element) BOOL constant. For the
    CondIfCondNotMatchingBool ERROR_IF test the dtype is taken from
    ``gtu.get_wrong_output_type`` instead. For CondIfCondShapeNotSizeOne the
    shape is randomly chosen to be [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2106
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used from the supplied tensors.
    ``else_tens`` is ignored. For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF tests, the selected block's constant gets a deliberately
    different shape.

    Returns:
        The COND_IF result tensor, or None if the ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        # Perturb every dimension. Dimensions <= 3 are only grown so they
        # stay positive.
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            bad_shape[idx] = dim + delta
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches the (valid) block output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant, which is also the block's output
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, wrong_shape_error, good_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == wrong_shape_error:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, good_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if valid else None
2175
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD in the THEN block and SUB in the ELSE
    block. INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT / LOGICAL_LEFT_SHIFT.
    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the selected block gets an input (or output) tensor whose shape differs
    from ``a``.

    Args:
        op: operator description dict (``op["op"]`` is the COND_IF op). It is
            also forwarded unchanged to the ERROR_IF validators.
        a, b: input tensors passed to the COND_IF and to each block.
        cond: value stored in the constant condition tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The COND_IF result tensor, or None if the ERROR_IF validation fails.

    Raises:
        AssertionError: if no block operators are defined for ``a.dtype``.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable is deliberately NOT named ``op``. Naming it ``op``
    # would shadow the operator description dict that the ERROR_IF
    # validators below must receive.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2258
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP carrying (counter, a, accumulator) through its blocks.

    The counter starts at ``iter_val`` and the accumulator at zero. The COND
    block computes ``counter > 0``. The BODY block computes ``acc + a`` and
    ``counter - 1`` and passes ``a`` through unchanged. For the loop-related
    ERROR_IF tests, selected block inputs/outputs or the COND output are given
    a wrong shape or dtype.

    Returns:
        The accumulator output tensor, or None if the ERROR_IF validation fails.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Outputs of the WHILE_LOOP operator for each loop-carried value
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape = acc.shape
    if error_name == ErrorIf.InputListOutputListMismatch:
        bad_acc = deepcopy(acc)
        for dim_idx in range(len(bad_acc.shape)):
            bad_acc.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        acc_out_shape = bad_acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        bad_iter = deepcopy(iter_tens)
        for dim_idx in range(len(bad_iter.shape)):
            bad_iter.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        if not bad_iter.shape:
            # Rank-0 counter: add a dimension so the shape actually differs
            bad_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        bad_acc = deepcopy(acc)
        for dim_idx in range(len(bad_acc.shape)):
            bad_acc.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)
    cond_inputs = (
        (bad_iter, a, bad_acc)
        if error_name == ErrorIf.InputListCondGraphMismatch
        else (iter_tens, a, acc)
    )
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter - 1, a, acc + a).
    # Local intermediate tensors must be declared here for the outputs.
    self.ser.addBasicBlock(body_block)
    body_inputs = (
        (bad_iter, a, bad_acc)
        if error_name == ErrorIf.InputListBodyGraphInputMismatch
        else (iter_tens, a, acc)
    )
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = bad_iter, bad_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if valid else None
2379
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator with inputs ``val1`` and ``val2``.

    Result tensors come from ``OutputShaper.fft2dOp``. For ERROR_IF tests, the
    input/output name lists may be altered by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``.

    Returns:
        The list of result tensors, or None if the ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholders, consts = op["operands"]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not validated:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.FFTAttribute(inverse, False)  # second argument: local_bound

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2430
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator with the single input tensor ``val``.

    Result tensors come from ``OutputShaper.rfft2dOp``. For ERROR_IF tests, the
    input/output name lists may be altered by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``.

    Returns:
        The list of result tensors, or None if the ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholders, consts = op["operands"]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not validated:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.RFFTAttribute(False)  # argument: local_bound

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2476
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Requested filters are intersected with what ``op`` supports.
    ``op["rank"]`` is an inclusive (min, max) pair and ``op["types"]`` lists
    the supported dtypes. A list-valued dtype entry also matches when its
    first element is requested.

    Args:
        op: operator description dict.
        shapeFilter: requested shapes. A falsy value means ``[None]``
            (no shape restriction).
        rankFilter: requested ranks, or None. When None and no shape is
            requested, the operator's rank range is used if it is no longer
            than 1..4; otherwise 1..4 is used.
        dtypeFilter: requested dtypes, or None for all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests. Its non-None
            ``param_reqs`` entries override the rank/dtype/shape filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys. Returns
        None for a negative test without a validator, or for any other
        ``testType``.
    """
    # Keep default rank coverage small (1-4 inclusive) to bound test sizes
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    lo_rank, hi_rank = op["rank"]
    op_ranks = range(lo_rank, hi_rank + 1)
    if rankFilter is not None:
        # Only keep the requested ranks the operator allows
        ranks = [rank for rank in rankFilter if lo_rank <= rank <= hi_rank]
    elif shapes[0] is None:
        # Bounded by the default range or the operator, whichever is smaller
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}
    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    # ERROR_IF requirements take precedence. Otherwise fall back to the filters
    # above, keeping only two shapes to limit the number of negative tests.
    req_rank = reqs["rank"]
    req_dtype = reqs["dtype"]
    req_shape = reqs["shape"]
    return {
        "shapeFilter": shapes[:2] if req_shape is None else req_shape,
        "rankFilter": ranks if req_rank is None else req_rank,
        "dtypeFilter": dtypes if req_dtype is None else req_dtype,
    }
2557
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for operator ``opName``.

    ``self.rng`` is re-seeded from ``self.random_seed`` first, so the list is
    reproducible. For ``testType == "negative"``, one pass is made per entry in
    the op's ``error_if_validators``. Otherwise a single pass is made with no
    error.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict generation to. None
            (the default) means no shape restriction; ``create_filter_lists``
            treats any falsy value as ``[None]``.
        rankFilter: optional ranks to restrict generation to.
        dtypeFilter: optional dtypes to restrict generation to.
        testType: "positive" or "negative".

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator (e.g. a negative run for
            # an op without ERROR_IF validators). Skip it instead of
            # discarding the tests already collected for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Only "positive" and "negative" reach this point;
                        # create_filter_lists returns None for anything else.
                        if testType == "positive":
                            testStr = f"{opName}_{shapeStr}_{typeStr}"
                        else:
                            testStr = (
                                f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                            )
                        if argStr:
                            testStr = f"{testStr}_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2664
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002665 def serializeTest(
2666 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2667 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002668 try:
2669 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002670 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002671 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002672
Jeremy Johnson0c716862023-04-13 17:18:19 +01002673 if self.args.verbose:
2674 print(f"Creating {testStr}")
2675
Eric Kunzee5e26762020-10-13 16:11:07 -07002676 # Create a serializer
2677 self.createSerializer(opName, testStr)
2678
Jeremy Johnson1271c442023-09-05 11:39:26 +01002679 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002680 if "error_if_validators" in op:
2681 error_if_validators = op["error_if_validators"]
2682 else:
2683 error_if_validators = None
2684
Kevin Cheng550ccc52021-03-03 11:21:43 -08002685 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002686 num_operands = pCount + cCount
2687
2688 if isinstance(dtype_or_dtypeList, list):
2689 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002690 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002691 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002692 else:
2693 dtypeList = [dtype_or_dtypeList] * (num_operands)
2694
Kevin Cheng93a16282021-08-31 16:14:03 -07002695 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002696 assert (
2697 len(shapeList) == num_operands
2698 ), "shapeList length {} must match number of operands {}".format(
2699 len(shapeList), num_operands
2700 )
2701 assert (
2702 len(dtypeList) == num_operands
2703 ), "dtypeList length {} must match number of operands {}".format(
2704 len(dtypeList), num_operands
2705 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002706
2707 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002708 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002709 except KeyError:
2710 qgen = None
2711
2712 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002713
Matthew Haddon1c00b712021-10-01 15:51:03 +01002714 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002715 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002716 else:
2717 qinfo = None
2718
Jeremy Johnson1271c442023-09-05 11:39:26 +01002719 # Extra meta data for the desc.json
2720 tensMeta = {}
2721
2722 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002723 if isinstance(testArgs, dict):
2724 # New interface with args info in dictionary
2725 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002726 assert "dg_type" in argsDict
2727 tvgInfo = tvgen_fcn(
2728 self, opName, dtypeList, shapeList, argsDict, error_name
2729 )
2730 if tvgInfo.dataGenDict:
2731 tensMeta["data_gen"] = tvgInfo.dataGenDict
2732 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002733
2734 result = build_fcn(
2735 self,
2736 op,
2737 tens,
2738 argsDict,
2739 validator_fcns=error_if_validators,
2740 error_name=error_name,
2741 qinfo=qinfo,
2742 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002743 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002744 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002745 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002746
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002747 try:
2748 if error_if_validators is None:
2749 if qinfo is not None:
2750 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2751 else:
2752 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002753 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002754 if qinfo is not None:
2755 result = build_fcn(
2756 self,
2757 op,
2758 *tens,
2759 *testArgs,
2760 validator_fcns=error_if_validators,
2761 error_name=error_name,
2762 qinfo=qinfo,
2763 )
2764 else:
2765 result = build_fcn(
2766 self,
2767 op,
2768 *tens,
2769 *testArgs,
2770 validator_fcns=error_if_validators,
2771 error_name=error_name,
2772 )
2773 except TypeError as e:
2774 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2775 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002776
Jeremy Johnson1271c442023-09-05 11:39:26 +01002777 if result:
Les Bell729b0352021-11-24 10:28:21 +00002778 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002779 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2780 # Add the compliance meta data
2781 # NOTE: This currently expects only one result output
2782 tensMeta["compliance"] = {
2783 "version": "0.1",
2784 "tensors": {result.resultTensor.name: result.complianceDict},
2785 }
2786 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002787 else:
2788 # The test is not valid
2789 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002790
Eric Kunzee5e26762020-10-13 16:11:07 -07002791 def createDynamicOpLists(self):
2792
Jeremy Johnson00423432022-09-12 17:27:37 +01002793 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2794 # Already created these lists (can occur when class is initialized more than once)
2795 return
2796
Eric Kunzee5e26762020-10-13 16:11:07 -07002797 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002798 if not self.args.level8k:
2799 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2800 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2801 else:
2802 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2803 KERNELS_2D = [[1, bigK], [bigK, 2]]
2804 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002805
Kevin Cheng1533b852021-09-01 12:51:58 -07002806 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002807 testName = "conv2d_{}x{}".format(k[0], k[1])
2808 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2809 self.TOSA_OP_LIST[testName]["filter"] = k
2810 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002811
Kevin Cheng550ccc52021-03-03 11:21:43 -08002812 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2813 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2814 "depthwise_conv2d_TEMPLATE"
2815 ].copy()
2816 self.TOSA_OP_LIST[testName]["filter"] = k
2817 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002818
Kevin Cheng550ccc52021-03-03 11:21:43 -08002819 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2820 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2821 "transpose_conv2d_TEMPLATE"
2822 ].copy()
2823 self.TOSA_OP_LIST[testName]["filter"] = k
2824 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002825
Kevin Cheng1533b852021-09-01 12:51:58 -07002826 for k in KERNELS_3D:
2827 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2828 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2829 self.TOSA_OP_LIST[testName]["filter"] = k
2830 self.TOSA_OP_LIST[testName]["template"] = False
2831
Eric Kunzee5e26762020-10-13 16:11:07 -07002832 # Delete any templates after having created any dynamic ops
2833 # This is a two-pass operation because it's bad practice to delete
2834 # keys from dictionaries while iterating
2835 keyList = []
2836 for k in self.TOSA_OP_LIST:
2837 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002838 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002839 keyList.append(k)
2840 continue
2841 except KeyError:
2842 pass
2843
2844 for k in keyList:
2845 del self.TOSA_OP_LIST[k]
2846
2847 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002848 """Fill in default fields for ops if they aren't already specified.
2849 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002850 for op in self.TOSA_OP_LIST:
2851
2852 # Required fields
2853 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002854 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002855 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002856 raise Exception(
2857 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2858 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002859
2860 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002861 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002862 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002863 raise Exception(
2864 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2865 op
2866 )
2867 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002868
2869 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002870 _ = self.TOSA_OP_LIST[op]["types"]
2871 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002872 raise Exception(
2873 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2874 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002875
2876 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002877 _ = self.TOSA_OP_LIST[op]["op"]
2878 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002879 raise Exception(
2880 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2881 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002882
2883 # Put in default rank range, if missing
2884 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002885 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002886 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002887 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002888
2889 # Tensor operator list
2890 # 'op': op name
2891 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002892 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2893 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002894 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2895 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002896 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002897
Kevin Cheng550ccc52021-03-03 11:21:43 -08002898 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002899 TYPE_INT_FP = [
2900 DType.INT8,
2901 DType.INT16,
2902 DType.INT32,
2903 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002904 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002905 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002906 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002907
Kevin Cheng550ccc52021-03-03 11:21:43 -08002908 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002909 TYPE_FI32 = [
2910 DType.FP32,
2911 DType.FP16,
2912 DType.BF16,
2913 DType.INT32,
2914 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002915 TYPE_FIB = [
2916 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002917 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002918 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002919 DType.INT8,
2920 DType.INT16,
2921 DType.INT32,
2922 DType.BOOL,
2923 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002924 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002925
James Ward24dbc422022-10-19 12:20:31 +01002926 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002927
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002928 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002929 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002930 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002931 [DType.INT8, DType.INT8, DType.INT32],
2932 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002933 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002934 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002935 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002936 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002937 ]
2938
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002939 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002940
2941 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002942 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002943 "argmax": {
2944 "op": Op.ARGMAX,
2945 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002946 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002947 "build_fcn": (
2948 build_argmax,
2949 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002950 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002951 TosaArgGen.agAxis,
2952 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002953 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002954 "error_if_validators": (
2955 TosaErrorValidator.evAxisSmallerZero,
2956 TosaErrorValidator.evAxisLargerRank,
2957 TosaErrorValidator.evArgmaxOutputRankMismatch,
2958 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2959 TosaErrorValidator.evWrongRank,
2960 TosaErrorValidator.evWrongInputType,
2961 TosaErrorValidator.evWrongOutputType,
2962 TosaErrorValidator.evWrongInputList,
2963 TosaErrorValidator.evWrongOutputList,
2964 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002965 "data_gen": {
2966 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2967 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002969 "avg_pool2d": {
2970 "op": Op.AVG_POOL2D,
2971 "operands": (1, 0),
2972 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002973 "build_fcn": (
2974 build_pool2d,
2975 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002976 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002977 TosaArgGen.agPooling,
2978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002979 "qgen": TosaQuantGen.qgUnary,
2980 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002981 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002982 "error_if_validators": (
2983 TosaErrorValidator.evKernelSmallerOne,
2984 TosaErrorValidator.evStrideSmallerOne,
2985 TosaErrorValidator.evPadSmallerZero,
2986 TosaErrorValidator.evWrongRank,
2987 TosaErrorValidator.evWrongInputType,
2988 TosaErrorValidator.evWrongOutputType,
2989 TosaErrorValidator.evWrongInputList,
2990 TosaErrorValidator.evWrongOutputList,
2991 TosaErrorValidator.evInputZeroPointNotZero,
2992 TosaErrorValidator.evOutputZeroPointNotZero,
2993 TosaErrorValidator.evPadLargerEqualKernel,
2994 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002995 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002996 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00002997 "data_gen": {
2998 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2999 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003000 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003001 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003002 "conv2d_TEMPLATE": {
3003 "op": Op.CONV2D,
3004 "operands": (1, 2),
3005 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003006 "build_fcn": (
3007 build_conv2d,
3008 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003009 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003010 TosaArgGen.agConv,
3011 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003012 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003013 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003014 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3015 "error_if_validators": (
3016 TosaErrorValidator.evWrongInputType,
3017 TosaErrorValidator.evWrongOutputType,
3018 TosaErrorValidator.evWrongInputList,
3019 TosaErrorValidator.evWrongOutputList,
3020 TosaErrorValidator.evInputZeroPointNotZero,
3021 TosaErrorValidator.evWeightZeroPointNotZero,
3022 TosaErrorValidator.evPadSmallerZero,
3023 TosaErrorValidator.evStrideSmallerOne,
3024 TosaErrorValidator.evDilationSmallerOne,
3025 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003026 TosaErrorValidator.evConvOutputShapeMismatch,
3027 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003028 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003029 "data_gen": {
3030 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3031 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003032 "template": True,
3033 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003034 # Templated operator. Filled in by createDynamicOpLists
3035 "conv3d_TEMPLATE": {
3036 "op": Op.CONV3D,
3037 "operands": (1, 2),
3038 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003039 "build_fcn": (
3040 build_conv3d,
3041 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003042 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003043 TosaArgGen.agConv,
3044 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003045 "qgen": TosaQuantGen.qgConv,
3046 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003047 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3048 "error_if_validators": (
3049 TosaErrorValidator.evWrongInputType,
3050 TosaErrorValidator.evWrongOutputType,
3051 TosaErrorValidator.evWrongInputList,
3052 TosaErrorValidator.evWrongOutputList,
3053 TosaErrorValidator.evInputZeroPointNotZero,
3054 TosaErrorValidator.evWeightZeroPointNotZero,
3055 TosaErrorValidator.evPadSmallerZero,
3056 TosaErrorValidator.evStrideSmallerOne,
3057 TosaErrorValidator.evDilationSmallerOne,
3058 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003059 TosaErrorValidator.evConvOutputShapeMismatch,
3060 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003061 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003062 "template": True,
3063 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003064 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003065 "depthwise_conv2d_TEMPLATE": {
3066 "op": Op.DEPTHWISE_CONV2D,
3067 "operands": (1, 2),
3068 "filter": [1, 1],
3069 "rank": (4, 4),
3070 "build_fcn": (
3071 build_depthwise_conv2d,
3072 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003073 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003074 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003075 ),
3076 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003077 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003078 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3079 "error_if_validators": (
3080 TosaErrorValidator.evWrongInputType,
3081 TosaErrorValidator.evWrongOutputType,
3082 TosaErrorValidator.evWrongInputList,
3083 TosaErrorValidator.evWrongOutputList,
3084 TosaErrorValidator.evInputZeroPointNotZero,
3085 TosaErrorValidator.evWeightZeroPointNotZero,
3086 TosaErrorValidator.evPadSmallerZero,
3087 TosaErrorValidator.evStrideSmallerOne,
3088 TosaErrorValidator.evDilationSmallerOne,
3089 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003090 TosaErrorValidator.evConvOutputShapeMismatch,
3091 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003092 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003093 "template": True,
3094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "fully_connected": {
3096 "op": Op.FULLY_CONNECTED,
3097 "operands": (1, 2),
3098 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099 "build_fcn": (
3100 build_fully_connected,
3101 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003102 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003103 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003106 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003107 "error_if_validators": (
3108 TosaErrorValidator.evInputZeroPointNotZero,
3109 TosaErrorValidator.evWeightZeroPointNotZero,
3110 TosaErrorValidator.evWrongRank,
3111 TosaErrorValidator.evWrongInputType,
3112 TosaErrorValidator.evWrongOutputType,
3113 TosaErrorValidator.evWrongInputList,
3114 TosaErrorValidator.evWrongOutputList,
3115 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003116 "data_gen": {
3117 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003120 "matmul": {
3121 "op": Op.MATMUL,
3122 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003123 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003124 "build_fcn": (
3125 build_matmul,
3126 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003127 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003128 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003129 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003130 "qgen": TosaQuantGen.qgMatmul,
3131 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003132 "error_if_validators": (
3133 TosaErrorValidator.evInputZeroPointNotZero,
3134 TosaErrorValidator.evWrongRank,
3135 TosaErrorValidator.evWrongInputType,
3136 TosaErrorValidator.evWrongOutputType,
3137 TosaErrorValidator.evWrongInputList,
3138 TosaErrorValidator.evWrongOutputList,
3139 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003140 "data_gen": {
3141 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003144 "max_pool2d": {
3145 "op": Op.MAX_POOL2D,
3146 "operands": (1, 0),
3147 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003148 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003149 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003150 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003151 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003152 TosaArgGen.agPooling,
3153 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003154 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003155 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003156 "error_if_validators": (
3157 TosaErrorValidator.evKernelSmallerOne,
3158 TosaErrorValidator.evStrideSmallerOne,
3159 TosaErrorValidator.evPadSmallerZero,
3160 TosaErrorValidator.evWrongRank,
3161 TosaErrorValidator.evWrongInputType,
3162 TosaErrorValidator.evWrongOutputType,
3163 TosaErrorValidator.evWrongInputList,
3164 TosaErrorValidator.evWrongOutputList,
3165 TosaErrorValidator.evPadLargerEqualKernel,
3166 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003167 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003169 "data_gen": {
3170 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3171 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003172 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003173 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003174 "transpose_conv2d_TEMPLATE": {
3175 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003176 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003177 "rank": (4, 4),
3178 "build_fcn": (
3179 build_transpose_conv2d,
3180 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003181 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003182 TosaArgGen.agTransposeConv2D,
3183 ),
3184 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003185 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003186 "invalid_test_validators": (
3187 TosaInvalidValidator.ivHeightWidthInvalid,
3188 TosaInvalidValidator.ivNonPositiveOutputShape,
3189 ),
3190 "error_if_validators": (
3191 TosaErrorValidator.evWrongInputType,
3192 TosaErrorValidator.evWrongOutputType,
3193 TosaErrorValidator.evWrongInputList,
3194 TosaErrorValidator.evWrongOutputList,
3195 TosaErrorValidator.evInputZeroPointNotZero,
3196 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003197 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003198 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003199 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003200 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003201 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003202 "template": True,
3203 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003204 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003205 "clamp": {
3206 "op": Op.CLAMP,
3207 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003208 "build_fcn": (
3209 build_clamp,
3210 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003211 TosaTensorValuesGen.tvgLazyGenDefault,
3212 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003213 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003214 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003215 "error_if_validators": (
3216 TosaErrorValidator.evMaxSmallerMin,
3217 TosaErrorValidator.evWrongInputType,
3218 TosaErrorValidator.evWrongOutputType,
3219 TosaErrorValidator.evWrongInputList,
3220 TosaErrorValidator.evWrongOutputList,
3221 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003222 "data_gen": {
3223 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3224 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003225 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 "sigmoid": {
3227 "op": Op.SIGMOID,
3228 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003229 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003230 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003231 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003232 TosaTensorValuesGen.tvgLazyGenDefault,
3233 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003235 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evWrongInputType,
3238 TosaErrorValidator.evWrongOutputType,
3239 TosaErrorValidator.evWrongInputList,
3240 TosaErrorValidator.evWrongOutputList,
3241 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003242 "data_gen": {
3243 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3244 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003245 },
3246 "tanh": {
3247 "op": Op.TANH,
3248 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003249 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003250 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003252 TosaTensorValuesGen.tvgLazyGenDefault,
3253 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003254 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003255 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003256 "error_if_validators": (
3257 TosaErrorValidator.evWrongInputType,
3258 TosaErrorValidator.evWrongOutputType,
3259 TosaErrorValidator.evWrongInputList,
3260 TosaErrorValidator.evWrongOutputList,
3261 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003262 "data_gen": {
3263 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3264 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003265 },
Won Jeon78155c62023-06-10 00:20:04 +00003266 "erf": {
3267 "op": Op.ERF,
3268 "operands": (1, 0),
3269 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003270 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003271 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003272 TosaTensorValuesGen.tvgLazyGenDefault,
3273 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003274 ),
3275 "types": TYPE_FP,
3276 "error_if_validators": (
3277 TosaErrorValidator.evWrongInputType,
3278 TosaErrorValidator.evWrongOutputType,
3279 TosaErrorValidator.evWrongInputList,
3280 TosaErrorValidator.evWrongOutputList,
3281 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003282 "data_gen": {
3283 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3284 },
3285 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 # Elementwise Binary Operators
3288 "add": {
3289 "op": Op.ADD,
3290 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003291 "build_fcn": (
3292 build_binary_broadcast,
3293 TosaTensorGen.tgBroadcastFuzz,
3294 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003295 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003296 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003298 "error_if_validators": (
3299 TosaErrorValidator.evRankMismatch,
3300 TosaErrorValidator.evWrongInputType,
3301 TosaErrorValidator.evWrongOutputType,
3302 TosaErrorValidator.evWrongInputList,
3303 TosaErrorValidator.evWrongOutputList,
3304 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003305 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003306 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003307 "data_gen": {
3308 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3309 },
3310 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 "arithmetic_right_shift": {
3313 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3314 "operands": (2, 0),
3315 "build_fcn": (
3316 build_arithmetic_right_shift,
3317 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003318 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 TosaArgGen.agArithmeticRightShift,
3320 ),
3321 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 "error_if_validators": (
3323 TosaErrorValidator.evRankMismatch,
3324 TosaErrorValidator.evWrongInputType,
3325 TosaErrorValidator.evWrongOutputType,
3326 TosaErrorValidator.evWrongInputList,
3327 TosaErrorValidator.evWrongOutputList,
3328 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003329 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 "bitwise_and": {
3333 "op": Op.BITWISE_AND,
3334 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335 "build_fcn": (
3336 build_binary_broadcast,
3337 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003338 TosaTensorValuesGen.tvgLazyGenDefault,
3339 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003342 "error_if_validators": (
3343 TosaErrorValidator.evRankMismatch,
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongInputList,
3347 TosaErrorValidator.evWrongOutputList,
3348 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003349 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003350 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 "bitwise_or": {
3353 "op": Op.BITWISE_OR,
3354 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 "build_fcn": (
3356 build_binary_broadcast,
3357 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003358 TosaTensorValuesGen.tvgLazyGenDefault,
3359 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003360 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003362 "error_if_validators": (
3363 TosaErrorValidator.evRankMismatch,
3364 TosaErrorValidator.evWrongInputType,
3365 TosaErrorValidator.evWrongOutputType,
3366 TosaErrorValidator.evWrongInputList,
3367 TosaErrorValidator.evWrongOutputList,
3368 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003369 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003370 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "bitwise_xor": {
3373 "op": Op.BITWISE_XOR,
3374 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003375 "build_fcn": (
3376 build_binary_broadcast,
3377 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003378 TosaTensorValuesGen.tvgLazyGenDefault,
3379 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003382 "error_if_validators": (
3383 TosaErrorValidator.evRankMismatch,
3384 TosaErrorValidator.evWrongInputType,
3385 TosaErrorValidator.evWrongOutputType,
3386 TosaErrorValidator.evWrongInputList,
3387 TosaErrorValidator.evWrongOutputList,
3388 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003389 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003390 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003391 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003392 "intdiv": {
3393 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003394 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003395 "build_fcn": (
3396 build_binary_broadcast,
3397 TosaTensorGen.tgBroadcastFuzz,
3398 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003399 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003401 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003402 "error_if_validators": (
3403 TosaErrorValidator.evRankMismatch,
3404 TosaErrorValidator.evWrongInputType,
3405 TosaErrorValidator.evWrongOutputType,
3406 TosaErrorValidator.evWrongInputList,
3407 TosaErrorValidator.evWrongOutputList,
3408 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003409 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003410 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003411 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003412 "logical_and": {
3413 "op": Op.LOGICAL_AND,
3414 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 "build_fcn": (
3416 build_binary_broadcast,
3417 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003418 TosaTensorValuesGen.tvgLazyGenDefault,
3419 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003420 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003421 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003422 "error_if_validators": (
3423 TosaErrorValidator.evRankMismatch,
3424 TosaErrorValidator.evWrongInputType,
3425 TosaErrorValidator.evWrongOutputType,
3426 TosaErrorValidator.evWrongInputList,
3427 TosaErrorValidator.evWrongOutputList,
3428 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003429 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003430 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 "logical_left_shift": {
3433 "op": Op.LOGICAL_LEFT_SHIFT,
3434 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 "build_fcn": (
3436 build_binary_broadcast,
3437 TosaTensorGen.tgBroadcastFuzz,
3438 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003439 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003442 "error_if_validators": (
3443 TosaErrorValidator.evRankMismatch,
3444 TosaErrorValidator.evWrongInputType,
3445 TosaErrorValidator.evWrongOutputType,
3446 TosaErrorValidator.evWrongInputList,
3447 TosaErrorValidator.evWrongOutputList,
3448 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003449 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003450 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003451 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003452 "logical_right_shift": {
3453 "op": Op.LOGICAL_RIGHT_SHIFT,
3454 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003455 "build_fcn": (
3456 build_binary_broadcast,
3457 TosaTensorGen.tgBroadcastFuzz,
3458 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003459 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003460 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003461 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003462 "error_if_validators": (
3463 TosaErrorValidator.evRankMismatch,
3464 TosaErrorValidator.evWrongInputType,
3465 TosaErrorValidator.evWrongOutputType,
3466 TosaErrorValidator.evWrongInputList,
3467 TosaErrorValidator.evWrongOutputList,
3468 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003469 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003470 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 "logical_or": {
3473 "op": Op.LOGICAL_OR,
3474 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003475 "build_fcn": (
3476 build_binary_broadcast,
3477 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003478 TosaTensorValuesGen.tvgLazyGenDefault,
3479 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003480 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003481 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 "error_if_validators": (
3483 TosaErrorValidator.evRankMismatch,
3484 TosaErrorValidator.evWrongInputType,
3485 TosaErrorValidator.evWrongOutputType,
3486 TosaErrorValidator.evWrongInputList,
3487 TosaErrorValidator.evWrongOutputList,
3488 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003489 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003490 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 "logical_xor": {
3493 "op": Op.LOGICAL_XOR,
3494 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003495 "build_fcn": (
3496 build_binary_broadcast,
3497 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003498 TosaTensorValuesGen.tvgLazyGenDefault,
3499 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 "error_if_validators": (
3503 TosaErrorValidator.evRankMismatch,
3504 TosaErrorValidator.evWrongInputType,
3505 TosaErrorValidator.evWrongOutputType,
3506 TosaErrorValidator.evWrongInputList,
3507 TosaErrorValidator.evWrongOutputList,
3508 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003509 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 "maximum": {
3513 "op": Op.MAXIMUM,
3514 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003515 "build_fcn": (
3516 build_binary_broadcast,
3517 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003518 TosaTensorValuesGen.tvgLazyGenDefault,
3519 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003520 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 "error_if_validators": (
3523 TosaErrorValidator.evRankMismatch,
3524 TosaErrorValidator.evWrongInputType,
3525 TosaErrorValidator.evWrongOutputType,
3526 TosaErrorValidator.evWrongInputList,
3527 TosaErrorValidator.evWrongOutputList,
3528 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003529 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003531 "data_gen": {
3532 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3533 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003534 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "minimum": {
3536 "op": Op.MINIMUM,
3537 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003538 "build_fcn": (
3539 build_binary_broadcast,
3540 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003541 TosaTensorValuesGen.tvgLazyGenDefault,
3542 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003545 "error_if_validators": (
3546 TosaErrorValidator.evRankMismatch,
3547 TosaErrorValidator.evWrongInputType,
3548 TosaErrorValidator.evWrongOutputType,
3549 TosaErrorValidator.evWrongInputList,
3550 TosaErrorValidator.evWrongOutputList,
3551 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003552 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003554 "data_gen": {
3555 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3556 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003557 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "mul": {
3559 "op": Op.MUL,
3560 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 "build_fcn": (
3562 build_mul,
3563 TosaTensorGen.tgBroadcastFuzz,
3564 TosaTensorValuesGen.tvgMul,
3565 TosaArgGen.agMul,
3566 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003567 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003568 "error_if_validators": (
3569 TosaErrorValidator.evWrongInputType,
3570 TosaErrorValidator.evWrongOutputType,
3571 TosaErrorValidator.evWrongInputList,
3572 TosaErrorValidator.evWrongOutputList,
3573 TosaErrorValidator.evRankMismatch,
3574 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003575 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003577 "data_gen": {
3578 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3579 },
3580 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 "pow": {
3583 "op": Op.POW,
3584 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 "build_fcn": (
3586 build_binary_broadcast,
3587 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003588 TosaTensorValuesGen.tvgPow,
3589 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003590 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 "error_if_validators": (
3593 TosaErrorValidator.evRankMismatch,
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003599 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003600 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003601 "data_gen": {
3602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 "sub": {
3606 "op": Op.SUB,
3607 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003608 "build_fcn": (
3609 build_binary_broadcast,
3610 TosaTensorGen.tgBroadcastFuzz,
3611 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003612 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003615 "error_if_validators": (
3616 TosaErrorValidator.evRankMismatch,
3617 TosaErrorValidator.evWrongInputType,
3618 TosaErrorValidator.evWrongOutputType,
3619 TosaErrorValidator.evWrongInputList,
3620 TosaErrorValidator.evWrongOutputList,
3621 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003622 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003623 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003624 "data_gen": {
3625 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3626 },
3627 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "table": {
3630 "op": Op.TABLE,
3631 # Use the automatic generation functions to create the input array
3632 # but create the table tensor in the build function, as it may be
3633 # a different type from the input
3634 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 "build_fcn": (
3636 build_table,
3637 TosaTensorGen.tgBasic,
3638 TosaTensorValuesGen.tvgDefault,
3639 TosaArgGen.agTable,
3640 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003641 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003642 "error_if_validators": (
3643 TosaErrorValidator.evWrongInputType,
3644 TosaErrorValidator.evWrongOutputType,
3645 TosaErrorValidator.evWrongInputList,
3646 TosaErrorValidator.evWrongOutputList,
3647 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 # Elementwise Unary operators
3650 "abs": {
3651 "op": Op.ABS,
3652 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003653 "build_fcn": (
3654 build_unary,
3655 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003656 TosaTensorValuesGen.tvgLazyGenDefault,
3657 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003659 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003660 "error_if_validators": (
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
3665 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003666 "data_gen": {
3667 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 "bitwise_not": {
3671 "op": Op.BITWISE_NOT,
3672 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
3674 build_unary,
3675 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003676 TosaTensorValuesGen.tvgLazyGenDefault,
3677 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "ceil": {
3688 "op": Op.CEIL,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_unary,
3692 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003693 TosaTensorValuesGen.tvgLazyGenDefault,
3694 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evWrongInputType,
3699 TosaErrorValidator.evWrongOutputType,
3700 TosaErrorValidator.evWrongInputList,
3701 TosaErrorValidator.evWrongOutputList,
3702 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003703 "data_gen": {
3704 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3705 },
3706 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003707 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 "clz": {
3709 "op": Op.CLZ,
3710 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003711 "build_fcn": (
3712 build_unary,
3713 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003714 TosaTensorValuesGen.tvgLazyGenDefault,
3715 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 "error_if_validators": (
3719 TosaErrorValidator.evWrongInputType,
3720 TosaErrorValidator.evWrongOutputType,
3721 TosaErrorValidator.evWrongInputList,
3722 TosaErrorValidator.evWrongOutputList,
3723 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 "exp": {
3726 "op": Op.EXP,
3727 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003728 "build_fcn": (
3729 build_unary,
3730 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003731 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003732 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003734 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003735 "error_if_validators": (
3736 TosaErrorValidator.evWrongInputType,
3737 TosaErrorValidator.evWrongOutputType,
3738 TosaErrorValidator.evWrongInputList,
3739 TosaErrorValidator.evWrongOutputList,
3740 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003741 "data_gen": {
3742 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 "floor": {
3746 "op": Op.FLOOR,
3747 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003748 "build_fcn": (
3749 build_unary,
3750 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003751 TosaTensorValuesGen.tvgLazyGenDefault,
3752 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003753 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003755 "error_if_validators": (
3756 TosaErrorValidator.evWrongInputType,
3757 TosaErrorValidator.evWrongOutputType,
3758 TosaErrorValidator.evWrongInputList,
3759 TosaErrorValidator.evWrongOutputList,
3760 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003761 "data_gen": {
3762 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3763 },
3764 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 "log": {
3767 "op": Op.LOG,
3768 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 "build_fcn": (
3770 build_unary,
3771 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003772 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003773 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 "error_if_validators": (
3777 TosaErrorValidator.evWrongInputType,
3778 TosaErrorValidator.evWrongOutputType,
3779 TosaErrorValidator.evWrongInputList,
3780 TosaErrorValidator.evWrongOutputList,
3781 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003782 "data_gen": {
3783 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3784 },
3785 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003787 "logical_not": {
3788 "op": Op.LOGICAL_NOT,
3789 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790 "build_fcn": (
3791 build_unary,
3792 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003793 TosaTensorValuesGen.tvgLazyGenDefault,
3794 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003797 "error_if_validators": (
3798 TosaErrorValidator.evWrongInputType,
3799 TosaErrorValidator.evWrongOutputType,
3800 TosaErrorValidator.evWrongInputList,
3801 TosaErrorValidator.evWrongOutputList,
3802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 "negate": {
3805 "op": Op.NEGATE,
3806 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 "build_fcn": (
3808 build_unary,
3809 TosaTensorGen.tgBasic,
3810 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003811 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "qgen": TosaQuantGen.qgUnary,
3814 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003815 "error_if_validators": (
3816 TosaErrorValidator.evInputZeroPointNotZero,
3817 TosaErrorValidator.evOutputZeroPointNotZero,
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003823 "data_gen": {
3824 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3825 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003826 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 "reciprocal": {
3828 "op": Op.RECIPROCAL,
3829 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003830 "build_fcn": (
3831 build_unary,
3832 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003833 TosaTensorValuesGen.tvgLazyGenDefault,
3834 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003835 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003836 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003837 "error_if_validators": (
3838 TosaErrorValidator.evWrongInputType,
3839 TosaErrorValidator.evWrongOutputType,
3840 TosaErrorValidator.evWrongInputList,
3841 TosaErrorValidator.evWrongOutputList,
3842 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003843 "data_gen": {
3844 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3845 },
3846 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003847 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003848 "rsqrt": {
3849 "op": Op.RSQRT,
3850 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003851 "build_fcn": (
3852 build_unary,
3853 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003854 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003855 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003856 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003858 "error_if_validators": (
3859 TosaErrorValidator.evWrongInputType,
3860 TosaErrorValidator.evWrongOutputType,
3861 TosaErrorValidator.evWrongInputList,
3862 TosaErrorValidator.evWrongOutputList,
3863 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003864 "data_gen": {
3865 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3866 },
3867 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003869 # Elementwise Ternary operators
3870 "select": {
3871 "op": Op.SELECT,
3872 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003873 "build_fcn": (
3874 build_select,
3875 TosaTensorGen.tgBroadcastFuzz,
3876 TosaTensorValuesGen.tvgSelect,
3877 None,
3878 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003880 "error_if_validators": (
3881 TosaErrorValidator.evRankMismatch,
3882 TosaErrorValidator.evWrongInputType,
3883 TosaErrorValidator.evWrongOutputType,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003887 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003888 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003889 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 # Comparison operators
3891 "equal": {
3892 "op": Op.EQUAL,
3893 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_comparison,
3896 TosaTensorGen.tgBroadcastFuzz,
3897 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003898 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evRankMismatch,
3903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongInputList,
3906 TosaErrorValidator.evWrongOutputList,
3907 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003908 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003910 "data_gen": {
3911 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3912 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 "greater_equal": {
3915 "op": Op.GREATER_EQUAL,
3916 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 "build_fcn": (
3918 build_comparison,
3919 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003920 TosaTensorValuesGen.tvgLazyGenDefault,
3921 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 "error_if_validators": (
3925 TosaErrorValidator.evRankMismatch,
3926 TosaErrorValidator.evWrongInputType,
3927 TosaErrorValidator.evWrongOutputType,
3928 TosaErrorValidator.evWrongInputList,
3929 TosaErrorValidator.evWrongOutputList,
3930 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003931 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003932 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003933 "data_gen": {
3934 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 "greater": {
3938 "op": Op.GREATER,
3939 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003940 "build_fcn": (
3941 build_comparison,
3942 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003943 TosaTensorValuesGen.tvgLazyGenDefault,
3944 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003947 "error_if_validators": (
3948 TosaErrorValidator.evRankMismatch,
3949 TosaErrorValidator.evWrongInputType,
3950 TosaErrorValidator.evWrongOutputType,
3951 TosaErrorValidator.evWrongInputList,
3952 TosaErrorValidator.evWrongOutputList,
3953 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003954 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003955 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003956 "data_gen": {
3957 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3958 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003960 # Reduction operators
3961 "reduce_all": {
3962 "op": Op.REDUCE_ALL,
3963 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003964 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003965 "build_fcn": (
3966 build_reduce,
3967 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003968 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003969 TosaArgGen.agAxis,
3970 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003971 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 "error_if_validators": (
3973 TosaErrorValidator.evAxisLargerRank,
3974 TosaErrorValidator.evAxisSmallerZero,
3975 TosaErrorValidator.evShapeOfAxisNotOne,
3976 TosaErrorValidator.evWrongInputType,
3977 TosaErrorValidator.evWrongOutputType,
3978 TosaErrorValidator.evWrongRank,
3979 TosaErrorValidator.evWrongInputList,
3980 TosaErrorValidator.evWrongOutputList,
3981 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003982 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 "reduce_any": {
3984 "op": Op.REDUCE_ANY,
3985 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003986 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003987 "build_fcn": (
3988 build_reduce,
3989 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003990 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003991 TosaArgGen.agAxis,
3992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003994 "error_if_validators": (
3995 TosaErrorValidator.evAxisLargerRank,
3996 TosaErrorValidator.evAxisSmallerZero,
3997 TosaErrorValidator.evShapeOfAxisNotOne,
3998 TosaErrorValidator.evWrongInputType,
3999 TosaErrorValidator.evWrongOutputType,
4000 TosaErrorValidator.evWrongRank,
4001 TosaErrorValidator.evWrongInputList,
4002 TosaErrorValidator.evWrongOutputList,
4003 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 "reduce_max": {
4006 "op": Op.REDUCE_MAX,
4007 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004008 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 "build_fcn": (
4010 build_reduce,
4011 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004012 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004013 TosaArgGen.agAxis,
4014 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004016 "error_if_validators": (
4017 TosaErrorValidator.evAxisLargerRank,
4018 TosaErrorValidator.evAxisSmallerZero,
4019 TosaErrorValidator.evShapeOfAxisNotOne,
4020 TosaErrorValidator.evWrongInputType,
4021 TosaErrorValidator.evWrongOutputType,
4022 TosaErrorValidator.evWrongRank,
4023 TosaErrorValidator.evWrongInputList,
4024 TosaErrorValidator.evWrongOutputList,
4025 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004026 "data_gen": {
4027 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004030 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004031 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004033 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004034 "build_fcn": (
4035 build_reduce,
4036 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004037 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004038 TosaArgGen.agAxis,
4039 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004040 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 "error_if_validators": (
4042 TosaErrorValidator.evAxisLargerRank,
4043 TosaErrorValidator.evAxisSmallerZero,
4044 TosaErrorValidator.evShapeOfAxisNotOne,
4045 TosaErrorValidator.evWrongInputType,
4046 TosaErrorValidator.evWrongOutputType,
4047 TosaErrorValidator.evWrongRank,
4048 TosaErrorValidator.evWrongInputList,
4049 TosaErrorValidator.evWrongOutputList,
4050 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004051 "data_gen": {
4052 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4053 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004054 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 "reduce_product": {
4056 "op": Op.REDUCE_PRODUCT,
4057 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004058 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004059 "build_fcn": (
4060 build_reduce,
4061 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004062 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004063 TosaArgGen.agAxis,
4064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004066 "error_if_validators": (
4067 TosaErrorValidator.evAxisLargerRank,
4068 TosaErrorValidator.evAxisSmallerZero,
4069 TosaErrorValidator.evShapeOfAxisNotOne,
4070 TosaErrorValidator.evWrongInputType,
4071 TosaErrorValidator.evWrongOutputType,
4072 TosaErrorValidator.evWrongRank,
4073 TosaErrorValidator.evWrongInputList,
4074 TosaErrorValidator.evWrongOutputList,
4075 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004076 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004077 "reduce_sum": {
4078 "op": Op.REDUCE_SUM,
4079 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004080 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004081 "build_fcn": (
4082 build_reduce,
4083 TosaTensorGen.tgBasic,
4084 TosaTensorValuesGen.tvgReduceSum,
4085 TosaArgGen.agAxis,
4086 ),
James Ward24dbc422022-10-19 12:20:31 +01004087 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004088 "error_if_validators": (
4089 TosaErrorValidator.evAxisLargerRank,
4090 TosaErrorValidator.evAxisSmallerZero,
4091 TosaErrorValidator.evShapeOfAxisNotOne,
4092 TosaErrorValidator.evWrongInputType,
4093 TosaErrorValidator.evWrongOutputType,
4094 TosaErrorValidator.evWrongRank,
4095 TosaErrorValidator.evWrongInputList,
4096 TosaErrorValidator.evWrongOutputList,
4097 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004098 "data_gen": {
4099 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004101 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004102 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 "concat": {
4104 "op": Op.CONCAT,
4105 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004106 "build_fcn": (
4107 build_concat,
4108 TosaTensorGen.tgConcat,
4109 TosaTensorValuesGen.tvgConcat,
4110 TosaArgGen.agAxis,
4111 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004112 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004113 "error_if_validators": (
4114 TosaErrorValidator.evAxisLargerRank,
4115 TosaErrorValidator.evAxisSmallerZero,
4116 TosaErrorValidator.evConcatInputRankMismatch,
4117 TosaErrorValidator.evConcatShapeSumMismatch,
4118 TosaErrorValidator.evConcatInputDimMismatch,
4119 TosaErrorValidator.evWrongInputType,
4120 TosaErrorValidator.evWrongOutputType,
4121 TosaErrorValidator.evWrongOutputList,
4122 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004123 },
4124 "pad": {
4125 "op": Op.PAD,
4126 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004127 "build_fcn": (
4128 build_pad,
4129 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004130 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 TosaArgGen.agPad,
4132 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004133 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004134 "error_if_validators": (
4135 TosaErrorValidator.evWrongInputType,
4136 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004137 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 TosaErrorValidator.evWrongOutputType,
4139 TosaErrorValidator.evWrongInputList,
4140 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004141 TosaErrorValidator.evRankMismatch,
4142 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004143 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004144 "data_gen": {
4145 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4146 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004147 },
Won Jeona21b2e82023-08-10 10:33:01 +00004148 "dim": {
4149 "op": Op.DIM,
4150 "operands": (1, 0),
4151 "build_fcn": (
4152 build_dim,
4153 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004154 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004155 TosaArgGen.agAxis,
4156 ),
4157 "types": TYPE_FIB,
4158 "error_if_validators": (
4159 TosaErrorValidator.evAxisLargerRank,
4160 TosaErrorValidator.evAxisSmallerZero,
4161 TosaErrorValidator.evWrongInputType,
4162 TosaErrorValidator.evWrongInputList,
4163 TosaErrorValidator.evWrongOutputList,
4164 TosaErrorValidator.evWrongRank,
4165 ),
4166 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004167 "reshape": {
4168 "op": Op.RESHAPE,
4169 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004170 "build_fcn": (
4171 build_reshape,
4172 TosaTensorGen.tgBasic,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004173 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004174 TosaArgGen.agReshape,
4175 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004176 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004177 "error_if_validators": (
4178 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4179 TosaErrorValidator.evWrongInputType,
4180 TosaErrorValidator.evWrongOutputType,
4181 TosaErrorValidator.evWrongInputList,
4182 TosaErrorValidator.evWrongOutputList,
4183 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004184 "data_gen": {
4185 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4186 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004187 },
4188 "reverse": {
4189 "op": Op.REVERSE,
4190 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004191 "build_fcn": (
4192 build_reverse,
4193 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004194 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004195 TosaArgGen.agAxis,
4196 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004197 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004198 "error_if_validators": (
4199 TosaErrorValidator.evAxisSmallerZero,
4200 TosaErrorValidator.evAxisLargerRank,
4201 TosaErrorValidator.evWrongInputType,
4202 TosaErrorValidator.evWrongOutputType,
4203 TosaErrorValidator.evWrongInputList,
4204 TosaErrorValidator.evWrongOutputList,
4205 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004206 },
4207 "slice": {
4208 "op": Op.SLICE,
4209 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004210 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004211 "build_fcn": (
4212 build_slice,
4213 TosaTensorGen.tgBasic,
4214 TosaTensorValuesGen.tvgDefault,
4215 TosaArgGen.agSlice,
4216 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004217 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004218 "error_if_validators": (
4219 TosaErrorValidator.evStartSmallerZero,
4220 TosaErrorValidator.evSizeSmallerEqualZero,
4221 TosaErrorValidator.evStartSizeOutsideBounds,
4222 TosaErrorValidator.evSizeOutputShapeMismatch,
4223 TosaErrorValidator.evInputSizeStartLengthMismatch,
4224 TosaErrorValidator.evWrongRank,
4225 TosaErrorValidator.evWrongInputType,
4226 TosaErrorValidator.evWrongOutputType,
4227 TosaErrorValidator.evWrongInputList,
4228 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004229 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004230 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004231 },
4232 "tile": {
4233 "op": Op.TILE,
4234 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004235 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004236 "build_fcn": (
4237 build_tile,
4238 TosaTensorGen.tgBasic,
4239 TosaTensorValuesGen.tvgDefault,
4240 TosaArgGen.agTile,
4241 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004242 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004243 "error_if_validators": (
4244 TosaErrorValidator.evWrongInputType,
4245 TosaErrorValidator.evWrongOutputType,
4246 TosaErrorValidator.evWrongInputList,
4247 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004248 TosaErrorValidator.evRankMismatch,
4249 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004250 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 },
4252 "transpose": {
4253 "op": Op.TRANSPOSE,
4254 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004255 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004256 "build_fcn": (
4257 build_transpose,
4258 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004259 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 TosaArgGen.agTranspose,
4261 ),
4262 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004263 "error_if_validators": (
4264 TosaErrorValidator.evIndexOutsideBounds,
4265 TosaErrorValidator.evIndexUsedTwice,
4266 TosaErrorValidator.evWrongInputType,
4267 TosaErrorValidator.evWrongOutputType,
4268 TosaErrorValidator.evWrongInputList,
4269 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004270 TosaErrorValidator.evWrongRank,
4271 TosaErrorValidator.evRankMismatch,
4272 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004273 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004275 # Data nodes
4276 "const": {
4277 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004278 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004279 "build_fcn": (
4280 build_const,
4281 TosaTensorGen.tgBasic,
4282 TosaTensorValuesGen.tvgDefault,
4283 None,
4284 ),
Luke Hutton65872422023-02-20 10:33:04 +00004285 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004287 "identity": {
4288 "op": Op.IDENTITY,
4289 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004290 "build_fcn": (
4291 build_unary,
4292 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004293 TosaTensorValuesGen.tvgLazyGenDefault,
4294 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004296 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004297 "data_gen": {
4298 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004301 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004302 "gather": {
4303 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004304 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004305 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 "build_fcn": (
4307 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004308 TosaTensorGen.tgGather,
4309 TosaTensorValuesGen.tvgGather,
4310 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 ),
James Ward24dbc422022-10-19 12:20:31 +01004312 "types": (
4313 DType.INT8,
4314 DType.INT16,
4315 DType.INT32,
4316 DType.FP16,
4317 DType.BF16,
4318 DType.FP32,
4319 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004320 "error_if_validators": (
4321 TosaErrorValidator.evWrongInputType,
4322 TosaErrorValidator.evWrongOutputType,
4323 TosaErrorValidator.evWrongInputList,
4324 TosaErrorValidator.evWrongOutputList,
4325 TosaErrorValidator.evWrongRank,
4326 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004327 "data_gen": {
4328 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4329 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004330 },
4331 "scatter": {
4332 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004333 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004334 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004335 "build_fcn": (
4336 build_scatter,
4337 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004338 TosaTensorValuesGen.tvgScatter,
4339 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004340 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004341 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004342 "error_if_validators": (
4343 TosaErrorValidator.evWrongInputType,
4344 TosaErrorValidator.evWrongOutputType,
4345 TosaErrorValidator.evWrongInputList,
4346 TosaErrorValidator.evWrongOutputList,
4347 TosaErrorValidator.evWrongRank,
4348 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004349 "data_gen": {
4350 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4351 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004352 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004353 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 "resize": {
4355 "op": Op.RESIZE,
4356 "operands": (1, 0),
4357 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004358 "build_fcn": (
4359 build_resize,
4360 TosaTensorGen.tgNHWC,
4361 TosaTensorValuesGen.tvgDefault,
4362 TosaArgGen.agResize,
4363 ),
James Ward24dbc422022-10-19 12:20:31 +01004364 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004365 "invalid_test_validators": (
4366 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004367 ),
4368 "error_if_validators": (
4369 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004370 TosaErrorValidator.evScaleSmallerEqualZero,
4371 TosaErrorValidator.evScaleNLargerMax,
4372 TosaErrorValidator.evScaleDLargerMax,
4373 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004375 TosaErrorValidator.evBorderSmallerMin,
4376 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004377 TosaErrorValidator.evWrongInputType,
4378 TosaErrorValidator.evWrongOutputType,
4379 TosaErrorValidator.evWrongRank,
4380 TosaErrorValidator.evWrongInputList,
4381 TosaErrorValidator.evWrongOutputList,
4382 TosaErrorValidator.evBatchMismatch,
4383 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004384 TosaErrorValidator.evResizeOutputShapeMismatch,
4385 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004386 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004387 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004388 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004389 "cast": {
4390 "op": Op.CAST,
4391 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004392 "build_fcn": (
4393 build_cast,
4394 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004395 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004396 TosaArgGen.agCast,
4397 ),
James Ward8b390432022-08-12 20:48:56 +01004398 "types": (
4399 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004400 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004401 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004402 DType.INT8,
4403 DType.INT16,
4404 DType.INT32,
4405 DType.BOOL,
4406 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 "error_if_validators": (
4408 TosaErrorValidator.evWrongInputType,
4409 TosaErrorValidator.evWrongOutputType,
4410 TosaErrorValidator.evWrongInputList,
4411 TosaErrorValidator.evWrongOutputList,
4412 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004413 "data_gen": {
4414 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4415 },
4416 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004417 },
4418 "rescale": {
4419 "op": Op.RESCALE,
4420 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004421 "build_fcn": (
4422 build_rescale,
4423 TosaTensorGen.tgBasic,
4424 TosaTensorValuesGen.tvgDefault,
4425 TosaArgGen.agRescale,
4426 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004427 "types": [
4428 DType.UINT8,
4429 DType.INT8,
4430 DType.INT16,
4431 DType.INT32,
4432 DType.INT48,
4433 DType.UINT16,
4434 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004435 "error_if_validators": (
4436 TosaErrorValidator.evInputZeroPointNotZero,
4437 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004438 TosaErrorValidator.evU16InputZeroPointNotValid,
4439 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 TosaErrorValidator.evScaleTrue,
4441 TosaErrorValidator.evScaleNotTrue,
4442 TosaErrorValidator.evWrongInputType,
4443 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 TosaErrorValidator.evWrongInputList,
4445 TosaErrorValidator.evWrongOutputList,
4446 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004448 # Custom
4449 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004450 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004451 # Two variants of cond_if, one that generates one of two constant tensors (no
4452 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4453 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004454 "cond_if_const": {
4455 "op": Op.COND_IF,
4456 "operands": (0, 2),
4457 "build_fcn": (
4458 build_cond_if_const,
4459 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004460 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004461 TosaArgGen.agCondIf,
4462 ),
4463 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004464 "error_if_validators": (
4465 TosaErrorValidator.evOutputListThenGraphMismatch,
4466 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004467 TosaErrorValidator.evCondIfCondNotMatchingBool,
4468 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004469 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004470 },
4471 "cond_if_binary": {
4472 "op": Op.COND_IF,
4473 "operands": (2, 0),
4474 "build_fcn": (
4475 build_cond_if_binary,
4476 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004477 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004478 TosaArgGen.agCondIf,
4479 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004480 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004481 "error_if_validators": (
4482 TosaErrorValidator.evInputListThenGraphMismatch,
4483 TosaErrorValidator.evInputListElseGraphMismatch,
4484 TosaErrorValidator.evOutputListThenGraphMismatch,
4485 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004486 TosaErrorValidator.evCondIfCondNotMatchingBool,
4487 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004489 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004490 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004491 "while_loop": {
4492 "op": Op.WHILE_LOOP,
4493 "operands": (0, 1),
4494 "build_fcn": (
4495 build_while_loop,
4496 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004497 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004498 TosaArgGen.agWhileLoop,
4499 ),
4500 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004501 "error_if_validators": (
4502 TosaErrorValidator.evInputListOutputListMismatch,
4503 TosaErrorValidator.evInputListCondGraphMismatch,
4504 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4505 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4506 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004507 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004508 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004509 },
Luke Hutton57287132023-02-06 14:54:18 +00004510 "fft2d": {
4511 "op": Op.FFT2D,
4512 "operands": (2, 0),
4513 "rank": (3, 3),
4514 "build_fcn": (
4515 build_fft2d,
4516 TosaTensorGen.tgFFT2d,
4517 TosaTensorValuesGen.tvgDefault,
4518 TosaArgGen.agFFT2d,
4519 ),
4520 "types": [DType.FP32],
4521 "error_if_validators": (
4522 TosaErrorValidator.evWrongInputType,
4523 TosaErrorValidator.evWrongOutputType,
4524 TosaErrorValidator.evWrongInputList,
4525 TosaErrorValidator.evWrongOutputList,
4526 TosaErrorValidator.evWrongRank,
4527 TosaErrorValidator.evBatchMismatch,
4528 TosaErrorValidator.evKernelNotPowerOfTwo,
4529 TosaErrorValidator.evFFTInputShapeMismatch,
4530 TosaErrorValidator.evFFTOutputShapeMismatch,
4531 ),
4532 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004533 "rfft2d": {
4534 "op": Op.RFFT2D,
4535 "operands": (1, 0),
4536 "rank": (3, 3),
4537 "build_fcn": (
4538 build_rfft2d,
4539 TosaTensorGen.tgRFFT2d,
4540 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004541 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004542 ),
4543 "types": [DType.FP32],
4544 "error_if_validators": (
4545 TosaErrorValidator.evWrongInputType,
4546 TosaErrorValidator.evWrongOutputType,
4547 TosaErrorValidator.evWrongInputList,
4548 TosaErrorValidator.evWrongOutputList,
4549 TosaErrorValidator.evWrongRank,
4550 TosaErrorValidator.evBatchMismatch,
4551 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004552 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004553 ),
4554 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004555 }
4556
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557
Eric Kunzee5e26762020-10-13 16:11:07 -07004558class OutputShaper:
4559 # Methods in this class compute the expected output shape and datatype
4560 # for common classes of operations
4561 def __init__(self):
4562 pass
4563
4564 # These methods return arguments that can be used for
4565 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary op that broadcasts.

    For every dimension of ``a``: when it is 1 and no ERROR_IF case is being
    generated (``error_name is None``), the output takes ``b``'s size for
    that dimension; otherwise it keeps ``a``'s size.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``integers`` and ``choice``.
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: optional ErrorIf case to inject.
            RankMismatch skips the equal-rank assertion.
            DimensionMismatch adds 1 to one randomly chosen output dimension.
            WrongOutputType picks an output dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    can_broadcast = error_name is None
    # b.shape[idx] is only evaluated when broadcasting applies, so a shorter
    # b (RankMismatch case) is never indexed out of range.
    shape = [
        b.shape[idx] if (dim == 1 and can_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The index is drawn unconditionally, so RNG consumption at this point
    # is the same for every error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        )
        wrong = list(set(candidates) - {a.dtype})
        return ser.addOutput(shape, rng.choice(wrong))

    return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004599
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose inputs must match exactly.

    Asserts that ``a`` and ``b`` have the same rank, dtype and per-dimension
    sizes. The output gets ``a``'s shape and dtype.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        a, b: input tensors exposing ``shape`` and ``dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004611
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output has ``a``'s shape. Its dtype is ``a.dtype``, unless
    ``error_name`` is WrongOutputType, in which case a different dtype is
    chosen at random from a fixed candidate pool.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``choice``.
        a: input tensor exposing ``shape`` and ``dtype``.
        error_name: optional ErrorIf case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004630
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a three-input select op (cond, a, b).

    The output rank follows ``cond``. For each dimension: when ``cond``'s
    size is 1 and no ERROR_IF case is being generated, the output uses the
    largest of the three input sizes; otherwise it keeps ``cond``'s size.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``integers`` and ``choice``.
        cond, a, b: input tensors exposing ``shape`` (and ``dtype`` for a/b).
        error_name: optional ErrorIf case to inject.
            RankMismatch skips the equal-rank assertion.
            DimensionMismatch adds 1 to one randomly chosen output dimension.
            WrongOutputType picks an output dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    can_broadcast = error_name is None
    shape = [
        max(c_dim, a.shape[idx], b.shape[idx])
        if (c_dim == 1 and can_broadcast)
        else c_dim
        for idx, c_dim in enumerate(cond.shape)
    ]

    # NOTE: the fuzz index range is taken from a's rank, not cond's; it is
    # drawn unconditionally, so RNG consumption here is the same for every
    # error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004664
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison op.

    The shape is broadcast from ``a``: a size-1 dimension of ``a`` takes
    ``b``'s size whenever ``b`` has that dimension. The dtype is BOOL,
    unless ``error_name`` is WrongOutputType, in which case a non-BOOL
    dtype is chosen at random.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``integers`` and ``choice``.
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: optional ErrorIf case to inject.
            RankMismatch skips the equal-rank assertion.
            DimensionMismatch adds 1 to one randomly chosen output dimension.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank guard keeps a shorter b (RankMismatch) in range.
    b_rank = len(b.shape)
    shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally, so RNG consumption here is the same for every
    # error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # Every candidate differs from the expected BOOL result type.
        non_bool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        return ser.addOutput(shape, rng.choice(non_bool))

    return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07004698
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    The output copies ``a``'s shape and sets ``axis`` to 1. The exceptions
    are the axis-related ERROR_IF cases (AxisSmallerZero, AxisLargerRank,
    ShapeOfAxisNotOne), which leave that dimension untouched. For
    ShapeOfAxisNotOne, if the dimension is already 1 it is replaced with a
    random size in [2, 10).

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor exposing ``shape`` (with ``.copy()``) and ``dtype``.
        axis: index into the shape of the reduced dimension.
        error_name: optional ErrorIf case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
        shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004727
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for an argmax along ``axis``.

    The output copies ``a``'s shape with ``axis`` removed. AxisSmallerZero
    and AxisLargerRank keep the full shape.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor exposing ``shape`` (with ``.copy()``).
        axis: index into the shape of the reduced dimension.
        error_name: optional ErrorIf case to inject.
            ArgmaxOutputRankMismatch randomly either drops the leading
            dimension (only when more than one remains) or appends a 1.
            ArgmaxOutputShapeMismatch adds a random 1..9 to every dimension.
            WrongOutputType picks a dtype other than INT32.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(shape) > 1:
            del shape[0]
        else:
            shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # One rng draw per dimension, in order from dimension 0 upwards.
        shape = [dim + rng.integers(1, 10) for dim in shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004761
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial output
    extent is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    ConvOutputShapeMismatch grows H and/or W by whole multiples of the
    stride. The output dtype is ``accum_dtype`` (INT32 for WrongInputType),
    or a type drawn from ``gtu.usableDTypes`` for WrongOutputType.

    Returns the tensor created by ``ser.addOutput``.
    """

    def spatial_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        effective = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return effective // stride + 1

    out_h = spatial_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = spatial_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        grow_by = [1, 2, 3]
        which = rng.choice(grow_by)  # 1: H only, 2: W only, 3: both
        # Grow in whole strides so the non-integer output ERROR_IF is not hit
        if which in [1, 3]:
            out_h += rng.choice(grow_by) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(grow_by) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With a bad input type, fall back to a potentially correct output type
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for it)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004813
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial
    output extent is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    ConvOutputShapeMismatch grows D, H, W or all three by whole multiples of
    the stride. The output dtype is ``accum_dtype`` (INT32 for
    WrongInputType), or a type drawn from ``gtu.usableDTypes`` for
    WrongOutputType.

    Returns the tensor created by ``ser.addOutput``.
    """

    def spatial_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        effective = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return effective // stride + 1

    out_d = spatial_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = spatial_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = spatial_extent(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        grow_by = [1, 2, 3, 4]
        which = rng.choice(grow_by)  # 1: D, 2: H, 3: W, 4: all of them
        # Grow in whole strides so the non-integer output ERROR_IF is not hit
        if which in [1, 4]:
            out_d += rng.choice(grow_by) * strides[0]
        if which in [2, 4]:
            out_h += rng.choice(grow_by) * strides[1]
        if which in [3, 4]:
            out_w += rng.choice(grow_by) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With a bad input type, fall back to a potentially correct output type
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for it)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4875
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial
    output extent is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    ConvOutputShapeMismatch grows H and/or W by whole multiples of the
    stride. The output dtype is ``accum_dtype`` (INT32 for WrongInputType),
    or a type drawn from ``gtu.usableDTypes`` for WrongOutputType.

    Returns the tensor created by ``ser.addOutput``.
    """

    def spatial_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        effective = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return effective // stride + 1

    # The HWCM filter puts the kernel height/width in dims 0 and 1
    out_h = spatial_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = spatial_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        grow_by = [1, 2, 3]
        which = rng.choice(grow_by)  # 1: H only, 2: W only, 3: both
        # Grow in whole strides so the non-integer output ERROR_IF is not hit
        if which in [1, 3]:
            out_h += rng.choice(grow_by) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(grow_by) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    # With a bad input type, fall back to a potentially correct output type
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for it)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004926
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (NHWC input).

    Each spatial extent is ``(in + pad_before + pad_after - k) // stride + 1``.
    A non-positive stride or a negative pad yields a 1x1 output, since such
    a test is invalid anyway. PoolingOutputShapeMismatch grows H and/or W by
    whole strides. The output dtype follows ``ifm`` unless WrongOutputType
    asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        grow_by = [1, 2, 3]
        which = rng.choice(grow_by)  # 1: H only, 2: W only, 3: both
        # Grow in whole strides so the non-integer output ERROR_IF is not hit
        if which in [1, 3]:
            out_h += rng.choice(grow_by) * stride[0]
        if which in [2, 3]:
            out_w += rng.choice(grow_by) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004964
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: input [N, IC] and filter [OC, IC] give an output of [N, OC].
    The output dtype is ``accum_dtype`` as supplied. ``rng`` and
    ``error_name`` are not used here.

    Returns the tensor created by ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated (or deliberately invalidated for ERROR_IF
    # tests) by the argument generator
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004977
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: a [N, H, C] and b [N, C, W] give an output of [N, H, W].

    Output dtype:
      * WrongOutputType: a type that is invalid for ``a.dtype``. INT8 input
        never gets INT32, INT16 input never gets INT48, and float inputs
        only get integer types. Any other input dtype picks from
        ``gtu.usableDTypes`` excluding ``accum_dtype``. Previously this case
        left ``incorrect_types`` unbound and raised ``UnboundLocalError``.
      * WrongInputType: INT32, a potentially correct output type.
      * otherwise: ``accum_dtype`` (validated in arg_gen).

    Returns the tensor created by ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Input dtype without a hand-written list: choose anything except
            # the accumulator type (same approach as the conv output shapers)
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005025
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT.

    The output starts from the first input's shape. When the ranks match
    and ``axis`` is valid, the remaining inputs' extents along ``axis`` are
    added to it. ConcatShapeSumMismatch inflates that sum. The output dtype
    follows the first input unless WrongOutputType asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    # Summing along axis is impossible with mismatched ranks or a bad axis
    can_sum_axis = error_name != ErrorIf.ConcatInputRankMismatch and error_name not in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if can_sum_axis:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - set([first.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005061
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    PadOutputShapeMismatch shrinks one random dim by 1 or 2. RankMismatch
    replaces the shape with one from ``gtu.get_rank_mismatch_shape``. The
    output dtype follows ``a`` unless WrongOutputType asks for a different
    one.

    Returns the tensor created by ``ser.addOutput``.
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape.copy())
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005092
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the (rank-0) output tensor for DIM.

    The output has an empty shape and dtype ``DType.SHAPE``. For
    WrongOutputType a type is drawn from a fixed list that does not contain
    ``DType.SHAPE``, so nothing needs excluding. ``a`` and ``axis`` are not
    used here.

    Returns the tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([], rng.choice(list(set(candidate_dtypes))))
5113
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE.

    The output shape is a copy of the requested ``shape``.
    TensorSizeInputOutputMismatch grows every dim by 1-9. The output dtype
    follows ``a`` unless WrongOutputType asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005138
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE.

    The output shape is a copy of ``size``. SizeOutputShapeMismatch nudges
    every dim: by +1/+2 when it is <= 2, otherwise by one of -2, -1, +1, +2.
    InputSizeStartLengthMismatch uses the input shape, and RankMismatch uses
    ``gtu.get_rank_mismatch_shape``. The output dtype follows ``input``
    unless WrongOutputType asks for a different one. ``start`` is not used
    here.

    Returns the tensor created by ``ser.addOutput``.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([input.dtype])))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            # Small dims only grow so they never drop to zero or below
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005172
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE.

    Each output dim is ``a.shape[i] * multiples[i]``. ``multiples`` must
    match the input rank (asserted). RankMismatch replaces the shape with
    one from ``gtu.get_rank_mismatch_shape``. The output dtype follows ``a``
    unless WrongOutputType asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * mult for dim, mult in zip(a.shape, multiples)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005201
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for TRANSPOSE.

    Output dim ``i`` is ``a.shape[perms[i]]``. For the IndexOutsideBounds
    and IndexUsedTwice cases the input shape is kept as-is.
    TensorSizeInputOutputMismatch grows every dim by 1-9, and RankMismatch
    uses ``gtu.get_rank_mismatch_shape``. The output dtype follows ``a``
    unless WrongOutputType asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    perms_are_usable = error_name not in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if perms_are_usable:
        output_shape = [a.shape[p] for p in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005234
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    ``values`` is rank 3 (not checked for WrongRank) and ``indices`` is
    rank 2, and the two share dim 0. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The output
    dtype follows ``values`` unless WrongOutputType asks for a different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch, channels = values.shape[0], values.shape[2]
    output_shape = [batch, indices.shape[1], channels]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([values.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005260
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    ``values_in`` (rank 3, not checked for WrongRank), ``indices`` (rank 2)
    and ``input`` (rank 3) must agree on N (dim 0), W (input/indices dim 1)
    and C (dim 2). The output takes ``values_in``'s shape object as-is.
    The output dtype follows ``values_in`` unless WrongOutputType asks for a
    different one.

    Returns the tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([values_in.dtype])))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005289
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The shape matches ``input``. The dtype is INT32 for an INT16 input and
    INT8 otherwise. The input must be INT8 or INT16 (asserted unless
    WrongInputType). WrongOutputType picks any other listed dtype.

    Returns the tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([d for d in candidate_dtypes if d != output_dtype])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005309
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE.

    ``scale`` is read as ``[y_n, y_d, x_n, x_d]``. Each output extent is
    ``((in - 1) * n - offset + border) // d + 1``. For ERROR_IF tests:
      * the scale terms are clamped to >= 1 for ScaleSmallerEqualZero;
      * OH/OW are clamped to >= 1, and (except for MaxDimExceeded) to
        ``gtu.MAX_RESIZE_DIMENSION - 1``;
      * ResizeOutputShapeMismatch moves OH and/or OW by the scale
        denominator, stepping down when stepping up would reach the limit;
      * WrongRank, BatchMismatch and ChannelMismatch corrupt the output dims.
    ``mode`` and ``input_dtype`` are not used here.

    Returns the tensor created by ``serializer.addOutput``.
    """
    y_num, y_den = scale[0], scale[1]
    x_num, x_den = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # scale/offset/border may have been altered for the ERROR_IF, so make
        # sure the output tensor itself is still valid
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def nudge(extent, step):
            # Step by the denominator so the non-integer ERROR_IF is not hit
            if extent + step >= gtu.MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
                return extent
            return extent + step

        which = rng.choice([1, 2, 3])  # 1: OH only, 2: OW only, 3: both
        if which in [1, 3]:
            oh = nudge(oh, y_den)
        if which in [2, 3]:
            ow = nudge(ow, x_den)

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses input.shape[0], presumably because
        # input.shape[3] may not exist when the input rank is wrong
        output_dims = [input.shape[0], oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [input.shape[0] + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [input.shape[0], oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005388
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type conversion.

    The output keeps ``val``'s shape and uses ``out_dtype`` unchanged.
    ``rng`` and ``error_name`` are not used here.

    Returns the tensor created by ``ser.addOutput``.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005392
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    The output uses the caller-supplied ``output_shape``. For
    ConvOutputShapeMismatch, H (index 1) and/or W (index 2) are increased by
    1-3. This modifies ``output_shape`` in place, so the caller's object sees
    the change too. The output dtype is ``accum_dtype`` (INT32 for
    WrongInputType), or a type drawn from ``gtu.usableDTypes`` for
    WrongOutputType.

    Returns the tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        bumps = [1, 2, 3]
        which = rng.choice(bumps)  # 1: H only, 2: W only, 3: both
        if which in [1, 3]:
            output_shape[1] += rng.choice(bumps)
        if which in [2, 3]:
            output_shape[2] += rng.choice(bumps)

    # With a bad input type, fall back to a potentially correct output type
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for it)
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005418
5419 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for FFT2D.

    ``ifm1`` and ``ifm2`` must share a dtype. They must also share a shape
    unless ``error_name`` is FFTInputShapeMismatch, and be rank 3 unless it is
    WrongRank. (Presumably they are the real and imaginary inputs; confirm
    against the op builder.)

    Both outputs get a copy of the input shape and the input dtype, optionally
    perturbed for one ERROR_IF case:

    * WrongOutputType: dtype is replaced by a random usable dtype other than
      FP32 (assumes FP32 is the only valid output type here; verify).
    * BatchMismatch: dimension 0 is increased by a random 1..9.
    * FFTOutputShapeMismatch: dimension 1 or 2 is increased by a random 1..9.

    Returns a list of the two ``serializer.addOutput`` results.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    # Both outputs share the same (possibly perturbed) shape and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5449
5450 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for RFFT2D.

    ``value`` must be rank 3 unless ``error_name`` is WrongRank. The output
    shape copies every leading dimension of the input and sets the last
    dimension to ``W // 2 + 1``, where ``W`` is the input's last dimension.
    The output dtype defaults to the input dtype.

    Optional ERROR_IF perturbations:

    * WrongOutputType: dtype is replaced by a random usable dtype other than
      FP32 (assumes FP32 is the only valid output type here; verify).
    * BatchMismatch: dimension 0 is increased by a random 1..9.
    * FFTOutputShapeMismatch: dimension 1 or 2 is increased by a random 1..9.

    Returns a list of the two ``serializer.addOutput`` results.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # The dimension is drawn before the increment, as in the original.
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    first = serializer.addOutput(out_shape, out_dtype)
    second = serializer.addOutput(out_shape, out_dtype)
    return [first, second]