blob: 63958a98b1c7b4df1274c88de1501cf313b8264b [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file.
# The copyright year is taken from today's date when this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
69 vals.append(v)
70 return tuple(sorted(vals))
71
72 self.random_float_range = {}
73 for dtype in (DType.FP32, DType.FP16, DType.BF16):
74 self.random_float_range[dtype] = convertFPRange(
75 args.tensor_fp_value_range,
76 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
77 )
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 def createSerializer(self, opName, testPath):
80 self.testPath = os.path.join(opName, testPath)
81
82 fullPath = os.path.join(self.basePath, self.testPath)
83 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010084 # Embed const data in the flatbuffer
85 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010086 if self.args.lazy_data_gen:
87 # Lazy data generation - so make constants files
88 constMode = ts.ConstMode.INPUTS
89 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 constMode = ts.ConstMode.EMBED_DUMP
91 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
93 def getSerializer(self):
94 return self.ser
95
Jeremy Johnson1271c442023-09-05 11:39:26 +010096 def serialize(self, testName, metaData=None):
97 path = Path(self.basePath) / self.testPath
98
99 # Write out TOSA flatbuffer binary
100 path_fb = path / f"{testName}.tosa"
101 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 fd.write(self.ser.serialize())
103
Jeremy Johnson1271c442023-09-05 11:39:26 +0100104 # Get JSON descriptor from serializer
105 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
106
107 if metaData:
108 # Add extra meta data to desc.json
109 desc["meta"] = metaData
110
111 # Validate desc.json before we output it
112 self.descSchemaValidator.validate_config(desc)
113
114 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100115 if "data_gen" in metaData:
116 if self.args.lazy_data_gen:
117 # Output datagen meta data as CPP data
118 path_md = path / f"{testName}_meta_data_gen.cpp"
119 with path_md.open("w") as fd:
120 fd.write(TOSA_AUTOGENERATED_HEADER)
121 fd.write("// Test meta data for data generation setup\n\n")
122 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
123 json.dump(metaData["data_gen"], fd)
124 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100125 if "compliance" in metaData:
126 # Output datagen meta data as CPP data
127 path_md = path / f"{testName}_meta_compliance.cpp"
128 with path_md.open("w") as fd:
129 fd.write(TOSA_AUTOGENERATED_HEADER)
130 fd.write("// Test meta data for compliance validation\n\n")
131 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
132 json.dump(metaData["compliance"], fd)
133 fd.write(')";\n\n')
134
135 # Write desc.json
136 path_desc = path / "desc.json"
137 with path_desc.open("w") as fd:
138 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
Matthew Haddon74567092021-07-16 15:38:20 +0100140 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000141 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100142 seed = self.random_seed + 1
143 self.rng = np.random.default_rng(seed)
144
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 def getDTypeRange(self, dtype, high_inclusive=False):
146 # Returns dtype value range boundaries (low, high)
147 # The high boundary is excluded in the range
148 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100149 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100150 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 elif dtype == DType.BOOL:
152 rng = (0, 2)
153 elif dtype == DType.UINT8:
154 rng = (0, 256)
155 elif dtype == DType.UINT16:
156 rng = (0, 65536)
157 elif dtype == DType.INT4:
158 # TOSA specific INT4 weight range from -7 to 7
159 rng = (-7, 8)
160 elif dtype == DType.INT8:
161 rng = (-128, 128)
162 elif dtype == DType.INT16:
163 rng = (-32768, 32768)
164 elif dtype in (DType.INT32, DType.SHAPE):
165 # restricting too large value for SHAPE
166 rng = (-(1 << 31), (1 << 31))
167 elif dtype == DType.INT48:
168 rng = (-(1 << 47), (1 << 47))
169 else:
170 raise Exception("Unknown dtype: {}".format(dtype))
171
172 if not high_inclusive:
173 # Exclusive high: low <= range < high
174 return rng
175 else:
176 # Inclusive range: low <= range <= high
177 return (rng[0], rng[1] - 1)
178
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000179 def getRandTensor(self, shape, dtype, data_range=None):
180 if data_range is None:
181 low, high = self.getDTypeRange(dtype)
182 else:
183 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100184
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100188 return np.int64(self.rng.integers(low=low, high=high, size=shape))
189 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
190 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
191
192 if dtype == DType.FP16:
193 return np.float16(f_tensor)
194 else:
195 f32_tensor = np.float32(f_tensor)
196 if dtype == DType.BF16:
197 # Floor the last 16 bits of each f32 value
198 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
199 else:
200 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100202 # All other integer types
203 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 placeholders = []
207
Kevin Cheng989cb052021-04-28 16:29:44 -0700208 assert len(shape_list) == len(dtype_list)
209
Jeremy Johnson1271c442023-09-05 11:39:26 +0100210 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100212 if not self.args.lazy_data_gen:
213 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700214 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
216 return placeholders
217
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 consts = []
220
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 assert len(shape_list) == len(dtype_list)
222
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100225 if not self.args.lazy_data_gen:
226 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700227 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700228
229 return consts
230
231 def makeShape(self, rank):
232 if self.targetted_shape:
233 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800234 return np.int32(
235 self.rng.integers(
236 low=self.args.tensor_shape_range[0],
237 high=self.args.tensor_shape_range[1],
238 size=rank,
239 )
240 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
242 def setTargetShape(self, shape):
243 self.targetted_shape = shape
244
245 def randInt(self, low=0, high=256):
246 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
247
248 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100249 low, high = self.getDTypeRange(dtype)
250
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100251 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100252 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100253 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100254 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100255 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
257 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 elif dtype == DType.BOOL:
259 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700260 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 # Special size
262 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700263
264 return np.int32(self.rng.integers(low, high, size=1))[0]
265
266 def shapeStr(self, shape):
267
268 sStr = []
269 # Convert to strings
270 for i in shape:
271 sStr.append(str(i))
272
Kevin Cheng550ccc52021-03-03 11:21:43 -0800273 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100275 def typeStr(self, dtype):
276 if isinstance(dtype, list) or isinstance(dtype, tuple):
277 assert len(dtype) >= 2
278 strs = [self.typeStr(t) for t in dtype]
279 # Limit types to the first 2 as the 3rd is the accumulator
280 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100282 if dtype in gtu.DTYPE_ATTRIBUTES:
283 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700284 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100285 raise Exception(
286 "Unknown dtype, cannot convert to string: {}".format(dtype)
287 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100289 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100290 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100291 if dtype in gtu.DTYPE_ATTRIBUTES:
292 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100294 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
Luke Hutton57287132023-02-06 14:54:18 +0000296 def constrictBatchSize(self, shape):
297 # Limit the batch size unless an explicit target shape set
298 if self.args.max_batch_size and not self.args.target_shapes:
299 shape[0] = min(shape[0], self.args.max_batch_size)
300 return shape
301
James Ward30124a82023-02-02 14:56:33 +0000302 def makeDimension(self):
303 return self.randInt(
304 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
305 )
306
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100307 def tensorComplianceMetaData(
308 self, op, inputType, argsDict, outputTensor, errorName
309 ):
310 if (
311 errorName
312 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
313 or not gtu.dtypeIsSupportedByCompliance(inputType)
314 ):
315 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100316 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100317
Jeremy Johnson1271c442023-09-05 11:39:26 +0100318 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100319 compliance_tens = {
320 "mode": None,
321 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
322 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
323 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100324 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
325 mode = gtu.ComplianceMode.DOT_PRODUCT
326 compliance_tens["dot_product_info"] = {
327 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100328 "ks": int(argsDict["ksb"])
329 if "ksb" in argsDict
330 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100331 }
332 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
333 mode = gtu.ComplianceMode.FP_SPECIAL
334 elif "compliance" in op and "ulp" in op["compliance"]:
335 mode = gtu.ComplianceMode.ULP
336 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
337 elif op["op"] == Op.REDUCE_PRODUCT:
338 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnson9a758382023-11-07 16:27:35 +0000339 elif op["op"] in (Op.EXP, Op.POW):
340 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnson1271c442023-09-05 11:39:26 +0100341 else:
342 mode = gtu.ComplianceMode.EXACT
343 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
344
345 return compliance_tens
346
347 # Build Op functions
348 # Create the output tensor (calling OutputShaper as needed)
349 # Do final tweaks to attributes (if necessary for errorIf)
350 # Add Op into graph
351 # Return resulting tensor information or BuildInfo
352
353 class BuildInfo:
354 """Enhanced build information containing result tensor and associated compliance dict."""
355
356 def __init__(self, resultTensor, complianceDict):
357 self.resultTensor = resultTensor
358 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700359
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000360 def build_unary(
361 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
362 ):
363 assert len(inputs) == 1
364 a = inputs[0]
365 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100366
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000367 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100368
369 # Ensure new output type has correct qinfo
370 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000371 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000372 qinfo = [
373 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000374 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000375 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100376
377 # Invalidate Input/Output list for error if checks.
378 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000379 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100380 pCount, cCount = op["operands"]
381 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000382 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
383 self, error_name, input_list, output_list
384 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100385
Les Bell729b0352021-11-24 10:28:21 +0000386 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100387 self.ser,
388 validator_fcns,
389 error_name,
390 op=op,
391 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000392 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000394 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100395 input_list=input_list,
396 output_list=output_list,
397 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000398 ):
399 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100400
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000401 attr = None
402 if op["op"] == Op.NEGATE:
403 attr = ts.TosaSerializerAttribute()
404 attr.NegateAttribute(qinfo[0], qinfo[1])
405
406 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000407
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000408 compliance = self.tensorComplianceMetaData(
409 op, a.dtype, args_dict, result_tensor, error_name
410 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000411 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000413 def build_binary_broadcast(
414 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
415 ):
416 assert len(inputs) == 2
417 a, b = inputs
418 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000419 self.ser, self.rng, a, b, error_name
420 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100421
422 # Invalidate Input/Output list for error if checks.
423 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000424 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100425 pCount, cCount = op["operands"]
426 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
428 self, error_name, input_list, output_list
429 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100430
Les Bell729b0352021-11-24 10:28:21 +0000431 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100432 self.ser,
433 validator_fcns,
434 error_name,
435 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 input1=a,
437 input2=b,
438 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000439 output_dtype=result_tensor.dtype,
440 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100441 input_list=input_list,
442 output_list=output_list,
443 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000444 ):
445 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100446
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000447 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000448
Jeremy Johnson9a758382023-11-07 16:27:35 +0000449 compliance = self.tensorComplianceMetaData(
450 op, a.dtype, args_dict, result_tensor, error_name
451 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000452
453 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700454
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100455 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700456 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700458 return result_tens
459
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000460 def build_arithmetic_right_shift(
461 self, op, a, b, round, validator_fcns=None, error_name=None
462 ):
463 result_tens = OutputShaper.binaryBroadcastOp(
464 self.ser, self.rng, a, b, error_name
465 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100466
467 # Invalidate Input/Output list for error if checks.
468 input_list = [a.name, b.name]
469 output_list = [result_tens.name]
470 pCount, cCount = op["operands"]
471 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000472 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
473 self, error_name, input_list, output_list
474 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100475
Les Bell729b0352021-11-24 10:28:21 +0000476 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100477 self.ser,
478 validator_fcns,
479 error_name,
480 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000481 input1=a,
482 input2=b,
483 input_dtype=a.dtype,
484 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000485 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100486 input_list=input_list,
487 output_list=output_list,
488 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000489 ):
490 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800491
492 attr = ts.TosaSerializerAttribute()
493 attr.ArithmeticRightShiftAttribute(round)
494
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000495 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800496 return result_tens
497
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100498 def build_mul(
499 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
500 ):
501 assert len(inputs) == 2
502 a, b = inputs
503 shift = args_dict["shift"]
504
505 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000506 self.ser, self.rng, a, b, error_name
507 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700508
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100509 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100510 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100511 result_tensor.setDtype(DType.INT32)
512
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100513 if error_name == ErrorIf.WrongOutputType:
514 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
515 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100516 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100517
518 # Invalidate Input/Output list for error if checks.
519 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100520 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100521 pCount, cCount = op["operands"]
522 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000523 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
524 self, error_name, input_list, output_list
525 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100526
Les Bell729b0352021-11-24 10:28:21 +0000527 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100528 self.ser,
529 validator_fcns,
530 error_name,
531 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 input1=a,
533 input2=b,
534 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100535 output_dtype=result_tensor.dtype,
536 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100537 input_list=input_list,
538 output_list=output_list,
539 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000540 ):
541 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700542
Kevin Chengaee1fac2020-11-11 13:54:06 -0800543 attr = ts.TosaSerializerAttribute()
544 attr.MulAttribute(shift)
545
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000546 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100547
548 compliance = self.tensorComplianceMetaData(
549 op, a.dtype, args_dict, result_tensor, error_name
550 )
551
552 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
555 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700556
Kevin Chengfe392ce2021-10-18 21:51:55 +0000557 attr = ts.TosaSerializerAttribute()
558 attr.TableAttribute(table)
559
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100560 # Invalidate Input/Output list for error if checks.
561 input_list = [a.name]
562 output_list = [result_tens.name]
563 pCount, cCount = op["operands"]
564 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000565 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
566 self, error_name, input_list, output_list
567 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100568
Les Bell729b0352021-11-24 10:28:21 +0000569 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100570 self.ser,
571 validator_fcns,
572 error_name,
573 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000574 input_shape=a.shape,
575 input_dtype=a.dtype,
576 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000577 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100578 input_list=input_list,
579 output_list=output_list,
580 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000581 ):
582 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100583
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700585
586 return result_tens
587
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100588 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
589 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
590
591 # Invalidate Input/Output list for error if checks.
592 input_list = [cond.name, a.name, b.name]
593 output_list = [result_tens.name]
594 pCount, cCount = op["operands"]
595 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000596 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
597 self, error_name, input_list, output_list
598 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100599
Les Bell729b0352021-11-24 10:28:21 +0000600 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100601 self.ser,
602 validator_fcns,
603 error_name,
604 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000605 input1=cond,
606 input2=a,
607 input3=b,
608 input_shape=a.shape,
609 input_dtype=a.dtype,
610 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000611 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100612 input_list=input_list,
613 output_list=output_list,
614 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000615 ):
616 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000618 self.ser.addOperator(
619 op["op"],
620 input_list,
621 output_list,
622 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700623 return result_tens
624
Jeremy Johnsona0150012023-11-15 15:52:06 +0000625 def build_comparison(
626 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
627 ):
628 assert len(inputs) == 2
629 a, b = inputs
630
631 result_tensor = OutputShaper.binaryComparisonOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000632 self.ser, self.rng, a, b, error_name
633 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100634
635 # Invalidate Input/Output list for error if checks.
636 input_list = [a.name, b.name]
Jeremy Johnsona0150012023-11-15 15:52:06 +0000637 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100638 pCount, cCount = op["operands"]
639 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000640 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
641 self, error_name, input_list, output_list
642 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100643
Les Bell729b0352021-11-24 10:28:21 +0000644 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100645 self.ser,
646 validator_fcns,
647 error_name,
648 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000649 input1=a,
650 input2=b,
651 input_shape=a.shape,
652 input_dtype=a.dtype,
Jeremy Johnsona0150012023-11-15 15:52:06 +0000653 output_shape=result_tensor.shape,
654 output_dtype=result_tensor.dtype,
655 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100656 input_list=input_list,
657 output_list=output_list,
658 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000659 ):
660 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100661
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000662 self.ser.addOperator(
663 op["op"],
664 input_list,
665 output_list,
666 )
Jeremy Johnsona0150012023-11-15 15:52:06 +0000667
668 compliance = self.tensorComplianceMetaData(
669 op, a.dtype, args_dict, result_tensor, error_name
670 )
671 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700672
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an argmax operator test along ``args_dict["axis"]``.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: list holding exactly one input tensor.
        args_dict: argument dict; "axis" is read here and the whole dict is
            forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs (defaults to None, like
            the sibling builders).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700716
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: list holding exactly one input tensor.
        args_dict: reads "kernel", "stride", "pad" and the optional
            "acc_type" (DType.UNKNOWN when absent); the whole dict is
            forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to PoolAttribute; [0, 0] is used
            when it is None.

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    # Named ifm rather than `input` so the builtin is not shadowed.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100790
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: [input, weights, bias] tensors.
        args_dict: reads "acc_type", "stride", "pad" (4 values) and
            "dilation"; the whole dict is forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed as ConvAttribute arguments; [0, 0]
            is used when it is None (as build_pool2d does).

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    # `weights` rather than `filter` so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700872
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: [input, weights, bias] tensors.
        args_dict: reads "acc_type", "stride", "pad" (6 values) and "dilation".
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed as ConvAttribute arguments; [0, 0]
            is used when it is None (as build_pool2d does).

    Returns:
        The output tensor itself (this builder produces no compliance meta
        data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # `weights` rather than `filter` so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
949
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        ifm, filter, bias: input, weight and bias tensors. The parameter
            name ``filter`` is kept for caller compatibility.
        accum_dtype: accumulator dtype forwarded to the output shaper.
        stride: stride values.
        out_pad: output padding (4 values).
        output_shape: requested output shape.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed as TransposeConvAttribute
            arguments; [0, 0] is used when it is None (as build_pool2d does).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
1017
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: [input, weights, bias] tensors.
        args_dict: reads "acc_type", "stride", "pad" and "dilation".
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed as ConvAttribute arguments; [0, 0]
            is used when it is None (as build_pool2d does).

    Returns:
        The output tensor itself (this builder produces no compliance meta
        data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # `weights` rather than `filter` so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tensor
1093
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: [input, weights, bias] tensors.
        args_dict: reads "acc_type"; the whole dict is forwarded to the
            compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to FullyConnectedAttribute; [0, 0]
            is used when it is None (as build_pool2d does).

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    # `weights` rather than `filter` so the builtin is not shadowed.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001149
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: the two input tensors [a, b].
        args_dict: reads "acc_type"; the whole dict is forwarded to the
            compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points passed to MatMulAttribute; [0, 0] is
            used when it is None (as build_pool2d does).

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points
    # instead of raising TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001199
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reduction operator test along ``args_dict["axis"]``.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: list holding exactly one input tensor.
        args_dict: reads "axis"; the whole dict is forwarded to the
            compliance meta data.
        validator_fcns: ERROR_IF validator functions forwarded to
            TosaErrorValidator.evValidateErrorIfs (defaults to None, like
            the sibling builders).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused; accepted for a uniform builder signature.

    Returns:
        TosaTestGen.BuildInfo for the output tensor (with no compliance data
        for Op.REDUCE_PRODUCT), or None when ERROR_IF validation rejects
        the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if op["op"] == Op.REDUCE_PRODUCT:
        # TODO: Add compliance support!
        compliance = None
    else:
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001248
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Create a CLAMP operator whose bounds are drawn at random.

    Two random values of the input dtype are drawn; normally the smaller
    becomes min_val and the larger max_val. For ErrorIf.MaxSmallerMin the
    pair is redrawn until the values differ and is then deliberately
    swapped. For BF16/FP16/FP32 inputs the bounds go into the last two
    ClampAttribute value slots (presumably the float bounds); otherwise
    they go into the first two.

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, ifm, error_name)

    def draw_pair():
        return [
            self.getRandNumberDType(ifm.dtype),
            self.getRandNumberDType(ifm.dtype),
        ]

    bounds = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values cannot give max < min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated on purpose.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=ifm.shape,
        output_shape=out_tensor.shape,
        input_dtype=ifm.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    is_float = ifm.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if ifm.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    if is_float:
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001314
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-relu style operator on tensor ``a``.

    A random FP32 number is serialized via LeakyReluAttribute. No ERROR_IF
    validation is done here: validator_fcns is unused, and error_name is
    only forwarded to OutputShaper.unaryOp.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    rand_value = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(rand_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1323
# TODO: PRELU needs an additional type/input; only `a` is wired up here.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a single-input operator on ``a`` with no attribute.

    validator_fcns is unused; error_name is only forwarded to
    OutputShaper.unaryOp.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1330
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Create a single-input, attribute-free operator (activation style).

    ``inputs`` must hold exactly one tensor; ``qinfo`` is accepted for a
    uniform builder signature but is not used.

    Returns:
        TosaTestGen.BuildInfo wrapping the output tensor and its compliance
        meta data, or None if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, ifm, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated on purpose.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        output_shape=out_tensor.shape,
        input_dtype=ifm.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001371
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT operator joining ``inputs`` along ``args_dict["axis"]``.

    Outside the WrongInputType ERROR_IF case the axis is asserted to be a
    plain ``int``. The first input's shape/dtype are what get reported to
    the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in inputs], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=out_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001419
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator for the single tensor in ``inputs``.

    ``args_dict`` supplies ``"pad"`` (an array that is flattened into the
    attribute), ``"pad_const_int"`` and ``"pad_const_fp"``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # The attribute is written through the serializer's builder, so it is
    # created here, before validation, exactly as the operator order requires.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    meta = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001477
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator reading dimension ``args_dict["axis"]`` of the input.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001523
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator taking ``a`` to ``newShape``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    shape_attr = ts.TosaSerializerAttribute()
    shape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, shape_attr)
    return out_tensor
1559
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator along ``args_dict["axis"]`` of the single input.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator permuting ``a`` by ``perms``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1635
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator extracting ``size`` elements of ``a`` from ``start``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1674
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator repeating ``a`` by ``multiples``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tensor
1709
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator reading from ``values``.

    A constant INT32 indices tensor of shape (values.shape[0], W) is created
    here. W is drawn with ``self.randInt`` over ``self.args.tensor_shape_range``
    and every index lies in [0, values.shape[1]).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1756
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 indices tensor of shape (N, W) is created here, where
    N = values_in.shape[0], K = values_in.shape[1] and W = input.shape[1].
    Every index lies in [0, K).

    The TOSA specification does not permit a single SCATTER to write the
    same output index more than once. So when W <= K, each batch row gets
    W distinct indices.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    N = values_in.shape[0]
    K = values_in.shape[1]
    W = input.shape[1]

    if W <= K:
        # Sample without replacement per batch so no output index repeats.
        rows = [self.rng.choice(K, size=W, replace=False) for _ in range(N)]
        indicies_arr = np.array(rows, dtype=np.int32).reshape(N, W)
    else:
        # NOTE(review): unique indices are impossible when W > K. Fall back
        # to sampling with replacement (the previous behaviour). The shape
        # generator should keep W <= K to produce spec-compliant tests.
        indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indicies, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indicies.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1801
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator for ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are passed both to the
    output shaper and to the serialized ResizeAttribute. ``input_dtype`` and
    ``output_dtype`` are given to the shaper and to the ERROR_IF validators.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1863
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add an IDENTITYN operator with inputs ``val`` and ``val2``.

    One output tensor per input is created via ``OutputShaper.unaryOp``.
    No ERROR_IF validation is run; ``validator_fcns`` is unused.

    Returns:
        The first output tensor. The second one is registered with the
        serializer but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the serializer op code (op["op"]), as every other build_* method
    # does, rather than the whole operator description.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1871
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Hand ``val`` to ``self.ser.addOutputTensor`` and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for interface
    compatibility with the other builders but are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001875
1876 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add a CAST operator converting ``val`` to ``out_dtype``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # ERROR_IF tests may corrupt the operator's input/output name lists.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1909
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001910 def build_rescale(
1911 self,
1912 op,
1913 val,
1914 out_dtype,
1915 scale32,
1916 double_round,
1917 per_channel,
1918 validator_fcns,
1919 error_name,
1920 ):
1921 result_tens = OutputShaper.typeConversionOp(
1922 self.ser, self.rng, val, out_dtype, error_name
1923 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001924
1925 if per_channel:
1926 nc = val.shape[-1]
1927 else:
1928 nc = 1
1929
1930 in_type_width = self.typeWidth(val.dtype)
1931 out_type_width = self.typeWidth(out_dtype)
1932
Kevin Cheng3a478572021-01-22 17:21:02 -08001933 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001934 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001935 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001936 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001937 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001938 in_type_width += 1
1939 elif error_name in [
1940 ErrorIf.InputZeroPointNotZero,
1941 ErrorIf.U16InputZeroPointNotValid,
1942 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001943 input_zp = self.randInt(-128, 128)
1944 if input_zp == 0:
1945 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001946 in_type_width += 1
1947 elif val.dtype == DType.UINT16:
1948 # Must come after ErrorIf.U16InputZeroPointNotValid check
1949 input_zp = self.rng.choice([0, 32768])
1950 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001951 else:
1952 input_zp = 0
1953
Kevin Cheng3a478572021-01-22 17:21:02 -08001954 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001955 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001956 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001957 elif out_dtype == DType.UINT8:
1958 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001959 out_type_width += 1
1960 elif error_name in [
1961 ErrorIf.OutputZeroPointNotZero,
1962 ErrorIf.U16OutputZeroPointNotValid,
1963 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001964 output_zp = self.randInt(-128, 128)
1965 if output_zp == 0:
1966 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001967 out_type_width += 1
1968 elif out_dtype == DType.UINT16:
1969 # Must come after ErrorIf.U16OutputZeroPointNotValid check
1970 output_zp = self.rng.choice([0, 32768])
1971 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001972 else:
1973 output_zp = 0
1974
1975 # Calculate scale based on:
1976 # scale = a *(2^output_width)/(2^input_width))
1977
1978 a = np.float32(self.rng.random(size=[nc]))
1979 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1980
1981 if scale32:
1982 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01001983 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07001984 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1985 else:
1986 # Cap the scaling at 2^15 - 1 for scale16
1987 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1988
Kevin Cheng550ccc52021-03-03 11:21:43 -08001989 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001990
1991 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1992 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00001993 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
1994 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001995
1996 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001997 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1998 scale_arr[i], scale32
1999 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002000 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2001 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002002
Kevin Cheng550ccc52021-03-03 11:21:43 -08002003 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002004 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002005 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002006 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002007 assert val.placeholderFilename
2008 values = np.load(
2009 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2010 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002011 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2012 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2013 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2014 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002015 if not np.all(np.array_equal(values, val_adj)):
2016 # Values changed so overwrite file with new values
2017 np.save(
2018 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2019 val_adj,
2020 False,
2021 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002022
Matthew Haddonc2025212021-10-08 21:21:05 +01002023 # Invalidate Input/Output list for error if checks.
2024 input_list = [val.name]
2025 output_list = [result_tens.name]
2026 pCount, cCount = op["operands"]
2027 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002028 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2029 self, error_name, input_list, output_list
2030 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002031
2032 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002033 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002034 self.ser,
2035 validator_fcns,
2036 error_name,
2037 op=op,
2038 input_dtype=val.dtype,
2039 output_dtype=out_dtype,
2040 input_shape=val.shape,
2041 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002042 scale32=scale32,
2043 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002044 input_list=input_list,
2045 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002046 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002047 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002048 ):
2049 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002050
Eric Kunzee5e26762020-10-13 16:11:07 -07002051 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002052 attr.RescaleAttribute(
2053 input_zp,
2054 output_zp,
2055 multiplier_arr,
2056 shift_arr,
2057 scale32,
2058 double_round,
2059 per_channel,
2060 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002061
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002062 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002063 return result_tens
2064
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    Args:
        op: operator description dict, forwarded to
            ``gtu.get_wrong_output_type`` when a wrong dtype is wanted.
        cond: the condition value stored in the constant.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The condition constant added to the serializer. It is a rank-0 BOOL
        tensor unless ``error_name`` asks for a wrong dtype
        (CondIfCondNotMatchingBool) or a shape holding more than one element
        (CondIfCondShapeNotSizeOne).
    """
    wants_wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if wants_wrong_type
        else DType.BOOL
    )

    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        # Randomly pick one of two shapes that hold more than one element
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        # A valid condition holds exactly one element (rank 0)
        cond_shape = []

    return self.ser.addConst(cond_shape, cond_type, [cond])
2081
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each just output an INT32 constant.

    Only the shape of ``then_tens`` is used and ``else_tens`` is ignored. The
    block contents are fresh random INT32 constants of that shape.

    Args:
        op: operator description dict; ``op["op"]`` is the COND_IF opcode.
        then_tens: tensor whose shape becomes the output shape.
        else_tens: unused.
        cond: condition value for the condition constant.
        validator_fcns: ERROR_IF validator functions to run, if any.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # For the output-list ERROR_IF cases, prepare a constant whose shape
    # differs from out_shape in every dimension.
    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            # Dims above 3 may shrink by up to 3. Smaller dims only grow, so
            # positive dims stay positive.
            if dim > 3:
                step = self.rng.choice([-3, -2, 2, 3])
            else:
                step = self.rng.choice([1, 2, 4])
            bad_shape[idx] = dim + step
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # A single INT32 result tensor stands for whichever block runs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block contains only its constant, which is also its output. The
    # block matching error_name gets the mis-shaped constant instead.
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, block_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_const = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_const = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_const)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2150
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each apply a binary op to a, b.

    For FP32/BF16/FP16/INT32 inputs, the then-block computes ADD and the
    else-block SUB. For INT8/INT16 inputs the blocks use LOGICAL_RIGHT_SHIFT
    and LOGICAL_LEFT_SHIFT instead.

    Args:
        op: operator description dict for COND_IF (``op["op"]`` is the
            opcode). It is also what the ERROR_IF validators receive.
        a, b: input tensors passed to both blocks.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to run, if any.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op pair.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # For the input/output-list ERROR_IF cases, make a copy of ``a`` with
    # every dimension perturbed, for use as a mismatching block tensor.
    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise rather than ``assert False``, so the check is not
        # stripped under ``python -O``. Stripping it would leave then_op unbound.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable is ``block_op``. Naming it ``op`` would rebind the
    # operator dict parameter, and the ERROR_IF validation below would then
    # receive the last block's opcode instead of the op dict.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # This block declares a mis-shaped first input
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # This block declares a mis-shaped output
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2233
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that adds ``a`` into an accumulator.

    The loop carries (counter, a, acc). The COND block tests ``counter > 0``.
    The BODY block computes ``a + acc`` and ``counter - 1``. The accumulator
    placeholder starts as zeros of ``a``'s shape.

    ERROR_IF cases are produced in two ways:
      * giving selected loop or block tensors a perturbed shape;
      * making the COND output non-BOOL or not a single element.

    Args:
        op: operator description dict; ``op["op"]`` is the WHILE_LOOP opcode.
        a: tensor added to the accumulator on each iteration.
        iter_val: initial value of the INT32 loop counter.
        validator_fcns: ERROR_IF validator functions to run, if any.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation fails.
    """
    # Loop counter; named iter_tens so the builtin ``iter`` is not shadowed
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor.
    # NOTE(review): the initial data is always INT32 zeros whatever a.dtype
    # is; presumably only INT32 is tested here - confirm.
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    # Mis-shaped counter/accumulator stand-ins for the block-list ERROR_IF
    # cases. The counter is rank 0, so a dimension is appended to it.
    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2354
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator test.

    Args:
        op: operator description dict. ``op["op"]`` is the opcode, and
            ``op["operands"]`` is a pair of counts summed into the operand
            count.
        val1, val2: the two input tensors. NOTE(review): presumably the real
            and imaginary parts - confirm against OutputShaper.fft2dOp.
        inverse: value for the FFT attribute's inverse flag.
        validator_fcns: ERROR_IF validator functions to run, if any.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The output tensors from OutputShaper.fft2dOp, or None if ERROR_IF
        validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    p_count, c_count = op["operands"]

    shapes_out = [tens.shape for tens in results]
    dtypes_out = [tens.dtype for tens in results]

    # Let the ERROR_IF helper adjust the name lists for this error_name
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], [tens.name for tens in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound; for now the attribute is always False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2405
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: operator description dict. ``op["op"]`` is the opcode, and
            ``op["operands"]`` is a pair of counts summed into the operand
            count.
        val: the input tensor.
        validator_fcns: ERROR_IF validator functions to run, if any.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The output tensors from OutputShaper.rfft2dOp, or None if ERROR_IF
        validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    p_count, c_count = op["operands"]

    shapes_out = [tens.shape for tens in results]
    dtypes_out = [tens.dtype for tens in results]

    # Let the ERROR_IF helper adjust the name lists for this error_name
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tens.name for tens in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound; for now the attribute is always False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2451
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator description dict. Reads ``op["rank"]`` (an inclusive
            ``(min, max)`` pair) and ``op["types"]``.
        shapeFilter: requested shapes; a falsy value means no shape filter.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. A list-valued op dtype is
            kept if its first element is requested.
        testType: "positive" or "negative"; anything else yields None.
        validator: ERROR_IF validator for negative tests. Its
            ``param_reqs`` override the rank/dtype/shape filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for an unknown testType or a negative test with no validator.
    """
    # Unless told otherwise, stay within ranks 1-4 to keep test sizes small
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    op_ranks = range(rank_lo, rank_hi + 1)
    if rankFilter is not None:
        # Drop requested ranks the operator does not support
        ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is None:
        # No rank or shape filter: use whichever of the op's rank range and
        # the default range is shorter
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        # Shapes were requested: allow every rank the op supports
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapeFilter, "rankFilter": ranks, "dtypeFilter": dtypes}

    if testType != "negative" or validator is None:
        return None

    # Negative (ERROR_IF) tests: the validator's requirements take priority.
    # Otherwise keep the filters above, with at most two shapes to keep the
    # number of tests small.
    reqs = validator(check=False, op=op)["param_reqs"]
    return {
        "shapeFilter": shapeFilter[:2] if reqs["shape"] is None else reqs["shape"],
        "rankFilter": ranks if reqs["rank"] is None else reqs["rank"],
        "dtypeFilter": dtypes if reqs["dtype"] is None else reqs["dtype"],
    }
2532
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict generation to.
            None (the default) means no shape restriction.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples, one per test.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # Use a fresh list per call. The old ``[None]`` default was one list
    # object shared by every call, and it is handed to create_filter_lists.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Re-seed on every call so the generated list does not depend on
    # earlier use of self.rng
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Each test is a tuple of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Skip requested shapes whose rank is not r
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists are (argStr, build-function args) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList

    return testList
2639
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Create, build and serialize a single test for operator ``opName``.

    Args:
        opName: Key of the operator entry in ``TOSA_OP_LIST``.
        testStr: Test name; passed to ``createSerializer`` and used in the
            log messages printed here.
        dtype_or_dtypeList: Either one dtype that is repeated for every
            operand, or an explicit list of dtypes (used as-is).
        error_name: ERROR_IF test name, forwarded unchanged to the qgen,
            tensor-values and build functions.
        shapeList: Operand shapes.
        testArgs: Either a dict (new interface; must contain "dg_type"),
            passed to the build function as one argument, or a sequence
            (old interface) whose items are passed positionally.

    Raises:
        Exception: If ``opName`` is not present in ``TOSA_OP_LIST``.
        TypeError: Re-raised from an old-interface build call after the
            build function, tensors and arguments have been printed.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    if self.args.verbose:
        print(f"Creating {testStr}")

    # A fresh serializer is created for every test
    self.createSerializer(opName, testStr)

    # Only the build and tensor-values functions are used here
    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT sizes its dtype list from shapeList instead of the
        # fixed operand counts (and skips the length checks below)
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        for label, items in (("shapeList", shapeList), ("dtypeList", dtypeList)):
            assert (
                len(items) == num_operands
            ), "{} length {} must match number of operands {}".format(
                label, len(items), num_operands
            )

    # Quantization info, only for ops that define a qgen function.
    # Generated before the tensor values, as in the original ordering.
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: arguments travel in a dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: arguments travel in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        try:
            if error_if_validators is None:
                # qinfo is appended positionally when present
                trailing = () if qinfo is None else (qinfo,)
                result = build_fcn(self, op, *tens, *testArgs, *trailing)
            else:
                extra_kwargs = {
                    "validator_fcns": error_if_validators,
                    "error_name": error_name,
                }
                if qinfo is not None:
                    extra_kwargs["qinfo"] = qinfo
                result = build_fcn(self, op, *tens, *testArgs, **extra_kwargs)
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The build reported the test as not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002765
def createDynamicOpLists(self):
    """Expand the convolution ``*_TEMPLATE`` entries of ``TOSA_OP_LIST``.

    For every kernel size, a shallow copy of the matching template entry is
    added under a kernel-specific name (e.g. ``conv2d_3x3``). The copy has
    its ``filter`` set to the kernel and ``template`` set to False. Any entry
    still flagged as a template is then removed. ``TOSA_OP_LIST`` is
    modified in place.

    If ``conv2d_TEMPLATE`` is already absent, the lists are assumed to have
    been expanded before and nothing is done.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_kernel_variant(base_name, kernel):
        # Shallow copy: only top-level keys of the new entry are changed
        variant = op_list[f"{base_name}_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(base_name, "x".join(map(str, kernel)))] = variant

    # Insertion order (kernel-major, then op) matches the original expansion
    for kernel in kernels_2d:
        for base_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(base_name, kernel)
    for kernel in kernels_3d:
        add_kernel_variant("conv3d", kernel)

    # Collect first, then delete, so the dict is never mutated while iterated
    stale = [name for name, entry in op_list.items() if entry.get("template", False)]
    for name in stale:
        del op_list[name]
2821
def initOpListDefaults(self):
    """Validate every ``TOSA_OP_LIST`` entry and fill in optional defaults.

    Each entry must provide:
      * "operands": a (placeholder count, const count) pair,
      * "build_fcn": a 4-tuple of (build, tensor gen, tensor values gen,
        arg gen) functions,
      * "types": the datatypes to be tested,
      * "op": the operator enum value.
    Entries without a "rank" are given ``DEFAULT_RANK_RANGE`` (in place).

    Raises:
        Exception: naming the first op (in dict order) whose entry is
            missing one of the required fields or has a malformed tuple.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Both tuples must unpack to exactly the expected number of items
        try:
            _pcount, _ccount = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(f"Op {name} is missing a valid type list in TOSA_OP_LIST")
        if "op" not in entry:
            raise Exception(f"Op {name} is missing the Op field in TOSA_OP_LIST")

        # Optional field: fall back to the default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002863
2864 # Tensor operator list
2865 # 'op': op name
2866 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002867 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2868 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002869 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2870 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002871 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002872
Kevin Cheng550ccc52021-03-03 11:21:43 -08002873 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002874 TYPE_INT_FP = [
2875 DType.INT8,
2876 DType.INT16,
2877 DType.INT32,
2878 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002879 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002880 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002881 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002882
Kevin Cheng550ccc52021-03-03 11:21:43 -08002883 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002884 TYPE_FI32 = [
2885 DType.FP32,
2886 DType.FP16,
2887 DType.BF16,
2888 DType.INT32,
2889 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002890 TYPE_FIB = [
2891 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002892 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002893 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002894 DType.INT8,
2895 DType.INT16,
2896 DType.INT32,
2897 DType.BOOL,
2898 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002899 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002900
James Ward24dbc422022-10-19 12:20:31 +01002901 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002902
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002903 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002904 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002905 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002906 [DType.INT8, DType.INT8, DType.INT32],
2907 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002908 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002909 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002910 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002911 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002912 ]
2913
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002914 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002915
2916 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002918 "argmax": {
2919 "op": Op.ARGMAX,
2920 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002921 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002922 "build_fcn": (
2923 build_argmax,
2924 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002925 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 TosaArgGen.agAxis,
2927 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002928 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002929 "error_if_validators": (
2930 TosaErrorValidator.evAxisSmallerZero,
2931 TosaErrorValidator.evAxisLargerRank,
2932 TosaErrorValidator.evArgmaxOutputRankMismatch,
2933 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2934 TosaErrorValidator.evWrongRank,
2935 TosaErrorValidator.evWrongInputType,
2936 TosaErrorValidator.evWrongOutputType,
2937 TosaErrorValidator.evWrongInputList,
2938 TosaErrorValidator.evWrongOutputList,
2939 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002940 "data_gen": {
2941 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2942 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002944 "avg_pool2d": {
2945 "op": Op.AVG_POOL2D,
2946 "operands": (1, 0),
2947 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002948 "build_fcn": (
2949 build_pool2d,
2950 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002951 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002952 TosaArgGen.agPooling,
2953 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 "qgen": TosaQuantGen.qgUnary,
2955 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002956 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002957 "error_if_validators": (
2958 TosaErrorValidator.evKernelSmallerOne,
2959 TosaErrorValidator.evStrideSmallerOne,
2960 TosaErrorValidator.evPadSmallerZero,
2961 TosaErrorValidator.evWrongRank,
2962 TosaErrorValidator.evWrongInputType,
2963 TosaErrorValidator.evWrongOutputType,
2964 TosaErrorValidator.evWrongInputList,
2965 TosaErrorValidator.evWrongOutputList,
2966 TosaErrorValidator.evInputZeroPointNotZero,
2967 TosaErrorValidator.evOutputZeroPointNotZero,
2968 TosaErrorValidator.evPadLargerEqualKernel,
2969 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002970 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002971 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00002972 "data_gen": {
2973 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002976 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002977 "conv2d_TEMPLATE": {
2978 "op": Op.CONV2D,
2979 "operands": (1, 2),
2980 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002981 "build_fcn": (
2982 build_conv2d,
2983 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002984 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 TosaArgGen.agConv,
2986 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002987 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002988 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002989 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2990 "error_if_validators": (
2991 TosaErrorValidator.evWrongInputType,
2992 TosaErrorValidator.evWrongOutputType,
2993 TosaErrorValidator.evWrongInputList,
2994 TosaErrorValidator.evWrongOutputList,
2995 TosaErrorValidator.evInputZeroPointNotZero,
2996 TosaErrorValidator.evWeightZeroPointNotZero,
2997 TosaErrorValidator.evPadSmallerZero,
2998 TosaErrorValidator.evStrideSmallerOne,
2999 TosaErrorValidator.evDilationSmallerOne,
3000 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003001 TosaErrorValidator.evConvOutputShapeMismatch,
3002 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003003 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003004 "data_gen": {
3005 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3006 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003007 "template": True,
3008 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003009 # Templated operator. Filled in by createDynamicOpLists
3010 "conv3d_TEMPLATE": {
3011 "op": Op.CONV3D,
3012 "operands": (1, 2),
3013 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 "build_fcn": (
3015 build_conv3d,
3016 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003017 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003018 TosaArgGen.agConv,
3019 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003020 "qgen": TosaQuantGen.qgConv,
3021 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003022 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3023 "error_if_validators": (
3024 TosaErrorValidator.evWrongInputType,
3025 TosaErrorValidator.evWrongOutputType,
3026 TosaErrorValidator.evWrongInputList,
3027 TosaErrorValidator.evWrongOutputList,
3028 TosaErrorValidator.evInputZeroPointNotZero,
3029 TosaErrorValidator.evWeightZeroPointNotZero,
3030 TosaErrorValidator.evPadSmallerZero,
3031 TosaErrorValidator.evStrideSmallerOne,
3032 TosaErrorValidator.evDilationSmallerOne,
3033 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003034 TosaErrorValidator.evConvOutputShapeMismatch,
3035 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003036 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003037 "template": True,
3038 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003039 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003040 "depthwise_conv2d_TEMPLATE": {
3041 "op": Op.DEPTHWISE_CONV2D,
3042 "operands": (1, 2),
3043 "filter": [1, 1],
3044 "rank": (4, 4),
3045 "build_fcn": (
3046 build_depthwise_conv2d,
3047 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003048 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003049 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 ),
3051 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003052 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003053 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3054 "error_if_validators": (
3055 TosaErrorValidator.evWrongInputType,
3056 TosaErrorValidator.evWrongOutputType,
3057 TosaErrorValidator.evWrongInputList,
3058 TosaErrorValidator.evWrongOutputList,
3059 TosaErrorValidator.evInputZeroPointNotZero,
3060 TosaErrorValidator.evWeightZeroPointNotZero,
3061 TosaErrorValidator.evPadSmallerZero,
3062 TosaErrorValidator.evStrideSmallerOne,
3063 TosaErrorValidator.evDilationSmallerOne,
3064 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003065 TosaErrorValidator.evConvOutputShapeMismatch,
3066 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003067 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003068 "template": True,
3069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "fully_connected": {
3071 "op": Op.FULLY_CONNECTED,
3072 "operands": (1, 2),
3073 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 "build_fcn": (
3075 build_fully_connected,
3076 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003077 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003078 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003081 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003082 "error_if_validators": (
3083 TosaErrorValidator.evInputZeroPointNotZero,
3084 TosaErrorValidator.evWeightZeroPointNotZero,
3085 TosaErrorValidator.evWrongRank,
3086 TosaErrorValidator.evWrongInputType,
3087 TosaErrorValidator.evWrongOutputType,
3088 TosaErrorValidator.evWrongInputList,
3089 TosaErrorValidator.evWrongOutputList,
3090 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003091 "data_gen": {
3092 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "matmul": {
3096 "op": Op.MATMUL,
3097 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003098 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099 "build_fcn": (
3100 build_matmul,
3101 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003102 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003103 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "qgen": TosaQuantGen.qgMatmul,
3106 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003107 "error_if_validators": (
3108 TosaErrorValidator.evInputZeroPointNotZero,
3109 TosaErrorValidator.evWrongRank,
3110 TosaErrorValidator.evWrongInputType,
3111 TosaErrorValidator.evWrongOutputType,
3112 TosaErrorValidator.evWrongInputList,
3113 TosaErrorValidator.evWrongOutputList,
3114 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003115 "data_gen": {
3116 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003119 "max_pool2d": {
3120 "op": Op.MAX_POOL2D,
3121 "operands": (1, 0),
3122 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003123 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003124 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003126 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003127 TosaArgGen.agPooling,
3128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003130 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003131 "error_if_validators": (
3132 TosaErrorValidator.evKernelSmallerOne,
3133 TosaErrorValidator.evStrideSmallerOne,
3134 TosaErrorValidator.evPadSmallerZero,
3135 TosaErrorValidator.evWrongRank,
3136 TosaErrorValidator.evWrongInputType,
3137 TosaErrorValidator.evWrongOutputType,
3138 TosaErrorValidator.evWrongInputList,
3139 TosaErrorValidator.evWrongOutputList,
3140 TosaErrorValidator.evPadLargerEqualKernel,
3141 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003142 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003143 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003144 "data_gen": {
3145 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003148 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003149 "transpose_conv2d_TEMPLATE": {
3150 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003151 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003152 "rank": (4, 4),
3153 "build_fcn": (
3154 build_transpose_conv2d,
3155 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003156 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003157 TosaArgGen.agTransposeConv2D,
3158 ),
3159 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003160 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003161 "invalid_test_validators": (
3162 TosaInvalidValidator.ivHeightWidthInvalid,
3163 TosaInvalidValidator.ivNonPositiveOutputShape,
3164 ),
3165 "error_if_validators": (
3166 TosaErrorValidator.evWrongInputType,
3167 TosaErrorValidator.evWrongOutputType,
3168 TosaErrorValidator.evWrongInputList,
3169 TosaErrorValidator.evWrongOutputList,
3170 TosaErrorValidator.evInputZeroPointNotZero,
3171 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003172 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003173 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003174 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003175 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003176 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003177 "template": True,
3178 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003179 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003180 "clamp": {
3181 "op": Op.CLAMP,
3182 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183 "build_fcn": (
3184 build_clamp,
3185 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003186 TosaTensorValuesGen.tvgLazyGenDefault,
3187 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003189 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003190 "error_if_validators": (
3191 TosaErrorValidator.evMaxSmallerMin,
3192 TosaErrorValidator.evWrongInputType,
3193 TosaErrorValidator.evWrongOutputType,
3194 TosaErrorValidator.evWrongInputList,
3195 TosaErrorValidator.evWrongOutputList,
3196 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003197 "data_gen": {
3198 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3199 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003200 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003201 "sigmoid": {
3202 "op": Op.SIGMOID,
3203 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003204 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003205 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003206 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003207 TosaTensorValuesGen.tvgLazyGenDefault,
3208 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003210 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003211 "error_if_validators": (
3212 TosaErrorValidator.evWrongInputType,
3213 TosaErrorValidator.evWrongOutputType,
3214 TosaErrorValidator.evWrongInputList,
3215 TosaErrorValidator.evWrongOutputList,
3216 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003217 "data_gen": {
3218 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3219 },
3220 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003221 },
3222 "tanh": {
3223 "op": Op.TANH,
3224 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003225 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003226 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003227 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003228 TosaTensorValuesGen.tvgLazyGenDefault,
3229 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003230 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003231 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003232 "error_if_validators": (
3233 TosaErrorValidator.evWrongInputType,
3234 TosaErrorValidator.evWrongOutputType,
3235 TosaErrorValidator.evWrongInputList,
3236 TosaErrorValidator.evWrongOutputList,
3237 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003238 "data_gen": {
3239 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3240 },
3241 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003242 },
Won Jeon78155c62023-06-10 00:20:04 +00003243 "erf": {
3244 "op": Op.ERF,
3245 "operands": (1, 0),
3246 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003247 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003248 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003249 TosaTensorValuesGen.tvgLazyGenDefault,
3250 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003251 ),
3252 "types": TYPE_FP,
3253 "error_if_validators": (
3254 TosaErrorValidator.evWrongInputType,
3255 TosaErrorValidator.evWrongOutputType,
3256 TosaErrorValidator.evWrongInputList,
3257 TosaErrorValidator.evWrongOutputList,
3258 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003259 "data_gen": {
3260 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3261 },
3262 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003263 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003264 # Elementwise Binary Operators
3265 "add": {
3266 "op": Op.ADD,
3267 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 "build_fcn": (
3269 build_binary_broadcast,
3270 TosaTensorGen.tgBroadcastFuzz,
3271 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003272 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003275 "error_if_validators": (
3276 TosaErrorValidator.evRankMismatch,
3277 TosaErrorValidator.evWrongInputType,
3278 TosaErrorValidator.evWrongOutputType,
3279 TosaErrorValidator.evWrongInputList,
3280 TosaErrorValidator.evWrongOutputList,
3281 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003282 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003283 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003284 "data_gen": {
3285 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3286 },
3287 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003289 "arithmetic_right_shift": {
3290 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3291 "operands": (2, 0),
3292 "build_fcn": (
3293 build_arithmetic_right_shift,
3294 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003295 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 TosaArgGen.agArithmeticRightShift,
3297 ),
3298 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 "error_if_validators": (
3300 TosaErrorValidator.evRankMismatch,
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003306 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "bitwise_and": {
3310 "op": Op.BITWISE_AND,
3311 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 "build_fcn": (
3313 build_binary_broadcast,
3314 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003315 TosaTensorValuesGen.tvgLazyGenDefault,
3316 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003317 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 "error_if_validators": (
3320 TosaErrorValidator.evRankMismatch,
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003326 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003327 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "bitwise_or": {
3330 "op": Op.BITWISE_OR,
3331 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 "build_fcn": (
3333 build_binary_broadcast,
3334 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003335 TosaTensorValuesGen.tvgLazyGenDefault,
3336 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003339 "error_if_validators": (
3340 TosaErrorValidator.evRankMismatch,
3341 TosaErrorValidator.evWrongInputType,
3342 TosaErrorValidator.evWrongOutputType,
3343 TosaErrorValidator.evWrongInputList,
3344 TosaErrorValidator.evWrongOutputList,
3345 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003346 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 "bitwise_xor": {
3350 "op": Op.BITWISE_XOR,
3351 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352 "build_fcn": (
3353 build_binary_broadcast,
3354 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003355 TosaTensorValuesGen.tvgLazyGenDefault,
3356 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003357 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003358 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003359 "error_if_validators": (
3360 TosaErrorValidator.evRankMismatch,
3361 TosaErrorValidator.evWrongInputType,
3362 TosaErrorValidator.evWrongOutputType,
3363 TosaErrorValidator.evWrongInputList,
3364 TosaErrorValidator.evWrongOutputList,
3365 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003366 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003369 "intdiv": {
3370 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003371 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 "build_fcn": (
3373 build_binary_broadcast,
3374 TosaTensorGen.tgBroadcastFuzz,
3375 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003376 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003377 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003378 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003379 "error_if_validators": (
3380 TosaErrorValidator.evRankMismatch,
3381 TosaErrorValidator.evWrongInputType,
3382 TosaErrorValidator.evWrongOutputType,
3383 TosaErrorValidator.evWrongInputList,
3384 TosaErrorValidator.evWrongOutputList,
3385 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003386 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003387 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "logical_and": {
3390 "op": Op.LOGICAL_AND,
3391 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 "build_fcn": (
3393 build_binary_broadcast,
3394 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003395 TosaTensorValuesGen.tvgLazyGenDefault,
3396 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 "error_if_validators": (
3400 TosaErrorValidator.evRankMismatch,
3401 TosaErrorValidator.evWrongInputType,
3402 TosaErrorValidator.evWrongOutputType,
3403 TosaErrorValidator.evWrongInputList,
3404 TosaErrorValidator.evWrongOutputList,
3405 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003406 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "logical_left_shift": {
3410 "op": Op.LOGICAL_LEFT_SHIFT,
3411 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 "build_fcn": (
3413 build_binary_broadcast,
3414 TosaTensorGen.tgBroadcastFuzz,
3415 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003416 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003419 "error_if_validators": (
3420 TosaErrorValidator.evRankMismatch,
3421 TosaErrorValidator.evWrongInputType,
3422 TosaErrorValidator.evWrongOutputType,
3423 TosaErrorValidator.evWrongInputList,
3424 TosaErrorValidator.evWrongOutputList,
3425 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003426 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003429 "logical_right_shift": {
3430 "op": Op.LOGICAL_RIGHT_SHIFT,
3431 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 "build_fcn": (
3433 build_binary_broadcast,
3434 TosaTensorGen.tgBroadcastFuzz,
3435 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003436 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003439 "error_if_validators": (
3440 TosaErrorValidator.evRankMismatch,
3441 TosaErrorValidator.evWrongInputType,
3442 TosaErrorValidator.evWrongOutputType,
3443 TosaErrorValidator.evWrongInputList,
3444 TosaErrorValidator.evWrongOutputList,
3445 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003446 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "logical_or": {
3450 "op": Op.LOGICAL_OR,
3451 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452 "build_fcn": (
3453 build_binary_broadcast,
3454 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003455 TosaTensorValuesGen.tvgLazyGenDefault,
3456 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003459 "error_if_validators": (
3460 TosaErrorValidator.evRankMismatch,
3461 TosaErrorValidator.evWrongInputType,
3462 TosaErrorValidator.evWrongOutputType,
3463 TosaErrorValidator.evWrongInputList,
3464 TosaErrorValidator.evWrongOutputList,
3465 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003466 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003468 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 "logical_xor": {
3470 "op": Op.LOGICAL_XOR,
3471 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 "build_fcn": (
3473 build_binary_broadcast,
3474 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003475 TosaTensorValuesGen.tvgLazyGenDefault,
3476 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003477 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003479 "error_if_validators": (
3480 TosaErrorValidator.evRankMismatch,
3481 TosaErrorValidator.evWrongInputType,
3482 TosaErrorValidator.evWrongOutputType,
3483 TosaErrorValidator.evWrongInputList,
3484 TosaErrorValidator.evWrongOutputList,
3485 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003486 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003487 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003489 "maximum": {
3490 "op": Op.MAXIMUM,
3491 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003492 "build_fcn": (
3493 build_binary_broadcast,
3494 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003495 TosaTensorValuesGen.tvgLazyGenDefault,
3496 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003499 "error_if_validators": (
3500 TosaErrorValidator.evRankMismatch,
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003506 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003507 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003508 "data_gen": {
3509 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3510 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 "minimum": {
3513 "op": Op.MINIMUM,
3514 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003515 "build_fcn": (
3516 build_binary_broadcast,
3517 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003518 TosaTensorValuesGen.tvgLazyGenDefault,
3519 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003520 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 "error_if_validators": (
3523 TosaErrorValidator.evRankMismatch,
3524 TosaErrorValidator.evWrongInputType,
3525 TosaErrorValidator.evWrongOutputType,
3526 TosaErrorValidator.evWrongInputList,
3527 TosaErrorValidator.evWrongOutputList,
3528 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003529 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003531 "data_gen": {
3532 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3533 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003534 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "mul": {
3536 "op": Op.MUL,
3537 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003538 "build_fcn": (
3539 build_mul,
3540 TosaTensorGen.tgBroadcastFuzz,
3541 TosaTensorValuesGen.tvgMul,
3542 TosaArgGen.agMul,
3543 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003545 "error_if_validators": (
3546 TosaErrorValidator.evWrongInputType,
3547 TosaErrorValidator.evWrongOutputType,
3548 TosaErrorValidator.evWrongInputList,
3549 TosaErrorValidator.evWrongOutputList,
3550 TosaErrorValidator.evRankMismatch,
3551 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003552 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003554 "data_gen": {
3555 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3556 },
3557 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "pow": {
3560 "op": Op.POW,
3561 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 "build_fcn": (
3563 build_binary_broadcast,
3564 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003565 TosaTensorValuesGen.tvgPow,
3566 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 "error_if_validators": (
3570 TosaErrorValidator.evRankMismatch,
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003576 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003577 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003578 "data_gen": {
3579 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 "sub": {
3583 "op": Op.SUB,
3584 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 "build_fcn": (
3586 build_binary_broadcast,
3587 TosaTensorGen.tgBroadcastFuzz,
3588 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003589 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003590 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 "error_if_validators": (
3593 TosaErrorValidator.evRankMismatch,
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003599 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003600 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003601 "data_gen": {
3602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3603 },
3604 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 "table": {
3607 "op": Op.TABLE,
3608 # Use the automatic generation functions to create the input array
3609 # but create the table tensor in the build function, as it may be
3610 # a different type from the input
3611 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 "build_fcn": (
3613 build_table,
3614 TosaTensorGen.tgBasic,
3615 TosaTensorValuesGen.tvgDefault,
3616 TosaArgGen.agTable,
3617 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003618 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evWrongInputType,
3621 TosaErrorValidator.evWrongOutputType,
3622 TosaErrorValidator.evWrongInputList,
3623 TosaErrorValidator.evWrongOutputList,
3624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 # Elementwise Unary operators
3627 "abs": {
3628 "op": Op.ABS,
3629 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 "build_fcn": (
3631 build_unary,
3632 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003633 TosaTensorValuesGen.tvgLazyGenDefault,
3634 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 "error_if_validators": (
3638 TosaErrorValidator.evWrongInputType,
3639 TosaErrorValidator.evWrongOutputType,
3640 TosaErrorValidator.evWrongInputList,
3641 TosaErrorValidator.evWrongOutputList,
3642 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003643 "data_gen": {
3644 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 "bitwise_not": {
3648 "op": Op.BITWISE_NOT,
3649 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 "build_fcn": (
3651 build_unary,
3652 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003653 TosaTensorValuesGen.tvgLazyGenDefault,
3654 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evWrongInputType,
3659 TosaErrorValidator.evWrongOutputType,
3660 TosaErrorValidator.evWrongInputList,
3661 TosaErrorValidator.evWrongOutputList,
3662 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003663 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 "ceil": {
3665 "op": Op.CEIL,
3666 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003667 "build_fcn": (
3668 build_unary,
3669 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003670 TosaTensorValuesGen.tvgLazyGenDefault,
3671 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003673 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 "error_if_validators": (
3675 TosaErrorValidator.evWrongInputType,
3676 TosaErrorValidator.evWrongOutputType,
3677 TosaErrorValidator.evWrongInputList,
3678 TosaErrorValidator.evWrongOutputList,
3679 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003680 "data_gen": {
3681 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3682 },
3683 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003684 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "clz": {
3686 "op": Op.CLZ,
3687 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003688 "build_fcn": (
3689 build_unary,
3690 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003691 TosaTensorValuesGen.tvgLazyGenDefault,
3692 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003693 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003695 "error_if_validators": (
3696 TosaErrorValidator.evWrongInputType,
3697 TosaErrorValidator.evWrongOutputType,
3698 TosaErrorValidator.evWrongInputList,
3699 TosaErrorValidator.evWrongOutputList,
3700 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003702 "exp": {
3703 "op": Op.EXP,
3704 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003705 "build_fcn": (
3706 build_unary,
3707 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003708 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003709 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003712 "error_if_validators": (
3713 TosaErrorValidator.evWrongInputType,
3714 TosaErrorValidator.evWrongOutputType,
3715 TosaErrorValidator.evWrongInputList,
3716 TosaErrorValidator.evWrongOutputList,
3717 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003718 "data_gen": {
3719 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 "floor": {
3723 "op": Op.FLOOR,
3724 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 "build_fcn": (
3726 build_unary,
3727 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003728 TosaTensorValuesGen.tvgLazyGenDefault,
3729 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 "error_if_validators": (
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003738 "data_gen": {
3739 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3740 },
3741 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003742 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 "log": {
3744 "op": Op.LOG,
3745 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003746 "build_fcn": (
3747 build_unary,
3748 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003749 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003750 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003751 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003753 "error_if_validators": (
3754 TosaErrorValidator.evWrongInputType,
3755 TosaErrorValidator.evWrongOutputType,
3756 TosaErrorValidator.evWrongInputList,
3757 TosaErrorValidator.evWrongOutputList,
3758 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003759 "data_gen": {
3760 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3761 },
3762 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003763 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 "logical_not": {
3765 "op": Op.LOGICAL_NOT,
3766 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 "build_fcn": (
3768 build_unary,
3769 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003770 TosaTensorValuesGen.tvgLazyGenDefault,
3771 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003774 "error_if_validators": (
3775 TosaErrorValidator.evWrongInputType,
3776 TosaErrorValidator.evWrongOutputType,
3777 TosaErrorValidator.evWrongInputList,
3778 TosaErrorValidator.evWrongOutputList,
3779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "negate": {
3782 "op": Op.NEGATE,
3783 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_unary,
3786 TosaTensorGen.tgBasic,
3787 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003788 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "qgen": TosaQuantGen.qgUnary,
3791 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evInputZeroPointNotZero,
3794 TosaErrorValidator.evOutputZeroPointNotZero,
3795 TosaErrorValidator.evWrongInputType,
3796 TosaErrorValidator.evWrongOutputType,
3797 TosaErrorValidator.evWrongInputList,
3798 TosaErrorValidator.evWrongOutputList,
3799 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003800 "data_gen": {
3801 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 "reciprocal": {
3805 "op": Op.RECIPROCAL,
3806 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 "build_fcn": (
3808 build_unary,
3809 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003810 TosaTensorValuesGen.tvgLazyGenDefault,
3811 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003814 "error_if_validators": (
3815 TosaErrorValidator.evWrongInputType,
3816 TosaErrorValidator.evWrongOutputType,
3817 TosaErrorValidator.evWrongInputList,
3818 TosaErrorValidator.evWrongOutputList,
3819 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003820 "data_gen": {
3821 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3822 },
3823 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 "rsqrt": {
3826 "op": Op.RSQRT,
3827 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003828 "build_fcn": (
3829 build_unary,
3830 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003831 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003832 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003833 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003834 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003835 "error_if_validators": (
3836 TosaErrorValidator.evWrongInputType,
3837 TosaErrorValidator.evWrongOutputType,
3838 TosaErrorValidator.evWrongInputList,
3839 TosaErrorValidator.evWrongOutputList,
3840 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003841 "data_gen": {
3842 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3843 },
3844 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003845 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003846 # Elementwise Ternary operators
3847 "select": {
3848 "op": Op.SELECT,
3849 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003850 "build_fcn": (
3851 build_select,
3852 TosaTensorGen.tgBroadcastFuzz,
3853 TosaTensorValuesGen.tvgSelect,
3854 None,
3855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003856 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003857 "error_if_validators": (
3858 TosaErrorValidator.evRankMismatch,
3859 TosaErrorValidator.evWrongInputType,
3860 TosaErrorValidator.evWrongOutputType,
3861 TosaErrorValidator.evWrongInputList,
3862 TosaErrorValidator.evWrongOutputList,
3863 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003864 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003865 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 # Comparison operators
3868 "equal": {
3869 "op": Op.EQUAL,
3870 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003871 "build_fcn": (
3872 build_comparison,
3873 TosaTensorGen.tgBroadcastFuzz,
3874 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003875 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003878 "error_if_validators": (
3879 TosaErrorValidator.evRankMismatch,
3880 TosaErrorValidator.evWrongInputType,
3881 TosaErrorValidator.evWrongOutputType,
3882 TosaErrorValidator.evWrongInputList,
3883 TosaErrorValidator.evWrongOutputList,
3884 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003885 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003886 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003887 "data_gen": {
3888 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3889 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 "greater_equal": {
3892 "op": Op.GREATER_EQUAL,
3893 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_comparison,
3896 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003897 TosaTensorValuesGen.tvgLazyGenDefault,
3898 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evRankMismatch,
3903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongInputList,
3906 TosaErrorValidator.evWrongOutputList,
3907 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003908 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003910 "data_gen": {
3911 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3912 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 "greater": {
3915 "op": Op.GREATER,
3916 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003917 "build_fcn": (
3918 build_comparison,
3919 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003920 TosaTensorValuesGen.tvgLazyGenDefault,
3921 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 "error_if_validators": (
3925 TosaErrorValidator.evRankMismatch,
3926 TosaErrorValidator.evWrongInputType,
3927 TosaErrorValidator.evWrongOutputType,
3928 TosaErrorValidator.evWrongInputList,
3929 TosaErrorValidator.evWrongOutputList,
3930 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003931 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003932 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003933 "data_gen": {
3934 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 # Reduction operators
3938 "reduce_all": {
3939 "op": Op.REDUCE_ALL,
3940 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003941 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003942 "build_fcn": (
3943 build_reduce,
3944 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003945 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003946 TosaArgGen.agAxis,
3947 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003948 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 "error_if_validators": (
3950 TosaErrorValidator.evAxisLargerRank,
3951 TosaErrorValidator.evAxisSmallerZero,
3952 TosaErrorValidator.evShapeOfAxisNotOne,
3953 TosaErrorValidator.evWrongInputType,
3954 TosaErrorValidator.evWrongOutputType,
3955 TosaErrorValidator.evWrongRank,
3956 TosaErrorValidator.evWrongInputList,
3957 TosaErrorValidator.evWrongOutputList,
3958 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003960 "reduce_any": {
3961 "op": Op.REDUCE_ANY,
3962 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003963 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003964 "build_fcn": (
3965 build_reduce,
3966 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003967 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003968 TosaArgGen.agAxis,
3969 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003971 "error_if_validators": (
3972 TosaErrorValidator.evAxisLargerRank,
3973 TosaErrorValidator.evAxisSmallerZero,
3974 TosaErrorValidator.evShapeOfAxisNotOne,
3975 TosaErrorValidator.evWrongInputType,
3976 TosaErrorValidator.evWrongOutputType,
3977 TosaErrorValidator.evWrongRank,
3978 TosaErrorValidator.evWrongInputList,
3979 TosaErrorValidator.evWrongOutputList,
3980 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003981 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003982 "reduce_max": {
3983 "op": Op.REDUCE_MAX,
3984 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003985 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003986 "build_fcn": (
3987 build_reduce,
3988 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003989 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003990 TosaArgGen.agAxis,
3991 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003992 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003993 "error_if_validators": (
3994 TosaErrorValidator.evAxisLargerRank,
3995 TosaErrorValidator.evAxisSmallerZero,
3996 TosaErrorValidator.evShapeOfAxisNotOne,
3997 TosaErrorValidator.evWrongInputType,
3998 TosaErrorValidator.evWrongOutputType,
3999 TosaErrorValidator.evWrongRank,
4000 TosaErrorValidator.evWrongInputList,
4001 TosaErrorValidator.evWrongOutputList,
4002 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004003 "data_gen": {
4004 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4005 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004007 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004008 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004009 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004010 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004011 "build_fcn": (
4012 build_reduce,
4013 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004014 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004015 TosaArgGen.agAxis,
4016 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004018 "error_if_validators": (
4019 TosaErrorValidator.evAxisLargerRank,
4020 TosaErrorValidator.evAxisSmallerZero,
4021 TosaErrorValidator.evShapeOfAxisNotOne,
4022 TosaErrorValidator.evWrongInputType,
4023 TosaErrorValidator.evWrongOutputType,
4024 TosaErrorValidator.evWrongRank,
4025 TosaErrorValidator.evWrongInputList,
4026 TosaErrorValidator.evWrongOutputList,
4027 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004028 "data_gen": {
4029 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 "reduce_product": {
4033 "op": Op.REDUCE_PRODUCT,
4034 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004035 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 "build_fcn": (
4037 build_reduce,
4038 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004039 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004040 TosaArgGen.agAxis,
4041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004042 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004043 "error_if_validators": (
4044 TosaErrorValidator.evAxisLargerRank,
4045 TosaErrorValidator.evAxisSmallerZero,
4046 TosaErrorValidator.evShapeOfAxisNotOne,
4047 TosaErrorValidator.evWrongInputType,
4048 TosaErrorValidator.evWrongOutputType,
4049 TosaErrorValidator.evWrongRank,
4050 TosaErrorValidator.evWrongInputList,
4051 TosaErrorValidator.evWrongOutputList,
4052 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004053 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004054 "reduce_sum": {
4055 "op": Op.REDUCE_SUM,
4056 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004057 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004058 "build_fcn": (
4059 build_reduce,
4060 TosaTensorGen.tgBasic,
4061 TosaTensorValuesGen.tvgReduceSum,
4062 TosaArgGen.agAxis,
4063 ),
James Ward24dbc422022-10-19 12:20:31 +01004064 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004065 "error_if_validators": (
4066 TosaErrorValidator.evAxisLargerRank,
4067 TosaErrorValidator.evAxisSmallerZero,
4068 TosaErrorValidator.evShapeOfAxisNotOne,
4069 TosaErrorValidator.evWrongInputType,
4070 TosaErrorValidator.evWrongOutputType,
4071 TosaErrorValidator.evWrongRank,
4072 TosaErrorValidator.evWrongInputList,
4073 TosaErrorValidator.evWrongOutputList,
4074 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004075 "data_gen": {
4076 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004078 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004079 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004080 "concat": {
4081 "op": Op.CONCAT,
4082 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004083 "build_fcn": (
4084 build_concat,
4085 TosaTensorGen.tgConcat,
4086 TosaTensorValuesGen.tvgConcat,
4087 TosaArgGen.agAxis,
4088 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004089 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004090 "error_if_validators": (
4091 TosaErrorValidator.evAxisLargerRank,
4092 TosaErrorValidator.evAxisSmallerZero,
4093 TosaErrorValidator.evConcatInputRankMismatch,
4094 TosaErrorValidator.evConcatShapeSumMismatch,
4095 TosaErrorValidator.evConcatInputDimMismatch,
4096 TosaErrorValidator.evWrongInputType,
4097 TosaErrorValidator.evWrongOutputType,
4098 TosaErrorValidator.evWrongOutputList,
4099 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004100 },
4101 "pad": {
4102 "op": Op.PAD,
4103 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 "build_fcn": (
4105 build_pad,
4106 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004107 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004108 TosaArgGen.agPad,
4109 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004110 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004111 "error_if_validators": (
4112 TosaErrorValidator.evWrongInputType,
4113 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004114 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004115 TosaErrorValidator.evWrongOutputType,
4116 TosaErrorValidator.evWrongInputList,
4117 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004118 TosaErrorValidator.evRankMismatch,
4119 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004120 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004121 "data_gen": {
4122 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4123 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004124 },
Won Jeona21b2e82023-08-10 10:33:01 +00004125 "dim": {
4126 "op": Op.DIM,
4127 "operands": (1, 0),
4128 "build_fcn": (
4129 build_dim,
4130 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004131 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004132 TosaArgGen.agAxis,
4133 ),
4134 "types": TYPE_FIB,
4135 "error_if_validators": (
4136 TosaErrorValidator.evAxisLargerRank,
4137 TosaErrorValidator.evAxisSmallerZero,
4138 TosaErrorValidator.evWrongInputType,
4139 TosaErrorValidator.evWrongInputList,
4140 TosaErrorValidator.evWrongOutputList,
4141 TosaErrorValidator.evWrongRank,
4142 ),
4143 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004144 "reshape": {
4145 "op": Op.RESHAPE,
4146 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004147 "build_fcn": (
4148 build_reshape,
4149 TosaTensorGen.tgBasic,
4150 TosaTensorValuesGen.tvgDefault,
4151 TosaArgGen.agReshape,
4152 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004153 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004154 "error_if_validators": (
4155 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4156 TosaErrorValidator.evWrongInputType,
4157 TosaErrorValidator.evWrongOutputType,
4158 TosaErrorValidator.evWrongInputList,
4159 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004160 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4161 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004163 },
4164 "reverse": {
4165 "op": Op.REVERSE,
4166 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004167 "build_fcn": (
4168 build_reverse,
4169 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004170 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004171 TosaArgGen.agAxis,
4172 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004173 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004174 "error_if_validators": (
4175 TosaErrorValidator.evAxisSmallerZero,
4176 TosaErrorValidator.evAxisLargerRank,
4177 TosaErrorValidator.evWrongInputType,
4178 TosaErrorValidator.evWrongOutputType,
4179 TosaErrorValidator.evWrongInputList,
4180 TosaErrorValidator.evWrongOutputList,
4181 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004182 },
4183 "slice": {
4184 "op": Op.SLICE,
4185 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004186 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004187 "build_fcn": (
4188 build_slice,
4189 TosaTensorGen.tgBasic,
4190 TosaTensorValuesGen.tvgDefault,
4191 TosaArgGen.agSlice,
4192 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004193 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004194 "error_if_validators": (
4195 TosaErrorValidator.evStartSmallerZero,
4196 TosaErrorValidator.evSizeSmallerEqualZero,
4197 TosaErrorValidator.evStartSizeOutsideBounds,
4198 TosaErrorValidator.evSizeOutputShapeMismatch,
4199 TosaErrorValidator.evInputSizeStartLengthMismatch,
4200 TosaErrorValidator.evWrongRank,
4201 TosaErrorValidator.evWrongInputType,
4202 TosaErrorValidator.evWrongOutputType,
4203 TosaErrorValidator.evWrongInputList,
4204 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004205 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004206 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004207 },
4208 "tile": {
4209 "op": Op.TILE,
4210 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004211 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004212 "build_fcn": (
4213 build_tile,
4214 TosaTensorGen.tgBasic,
4215 TosaTensorValuesGen.tvgDefault,
4216 TosaArgGen.agTile,
4217 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004218 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004219 "error_if_validators": (
4220 TosaErrorValidator.evWrongInputType,
4221 TosaErrorValidator.evWrongOutputType,
4222 TosaErrorValidator.evWrongInputList,
4223 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004224 TosaErrorValidator.evRankMismatch,
4225 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004226 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004227 },
4228 "transpose": {
4229 "op": Op.TRANSPOSE,
4230 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004231 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004232 "build_fcn": (
4233 build_transpose,
4234 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004235 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004236 TosaArgGen.agTranspose,
4237 ),
4238 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 "error_if_validators": (
4240 TosaErrorValidator.evIndexOutsideBounds,
4241 TosaErrorValidator.evIndexUsedTwice,
4242 TosaErrorValidator.evWrongInputType,
4243 TosaErrorValidator.evWrongOutputType,
4244 TosaErrorValidator.evWrongInputList,
4245 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004246 TosaErrorValidator.evWrongRank,
4247 TosaErrorValidator.evRankMismatch,
4248 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004251 # Data nodes
4252 "const": {
4253 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004254 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004255 "build_fcn": (
4256 build_const,
4257 TosaTensorGen.tgBasic,
4258 TosaTensorValuesGen.tvgDefault,
4259 None,
4260 ),
Luke Hutton65872422023-02-20 10:33:04 +00004261 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004263 "identity": {
4264 "op": Op.IDENTITY,
4265 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004266 "build_fcn": (
4267 build_unary,
4268 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004269 TosaTensorValuesGen.tvgLazyGenDefault,
4270 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004271 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004272 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004273 "data_gen": {
4274 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004276 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004277 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004278 "gather": {
4279 "op": Op.GATHER,
4280 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4281 "operands": (1, 0),
4282 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004283 "build_fcn": (
4284 build_gather,
4285 TosaTensorGen.tgBasic,
4286 TosaTensorValuesGen.tvgDefault,
4287 None,
4288 ),
James Ward24dbc422022-10-19 12:20:31 +01004289 "types": (
4290 DType.INT8,
4291 DType.INT16,
4292 DType.INT32,
4293 DType.FP16,
4294 DType.BF16,
4295 DType.FP32,
4296 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004297 "error_if_validators": (
4298 TosaErrorValidator.evWrongInputType,
4299 TosaErrorValidator.evWrongOutputType,
4300 TosaErrorValidator.evWrongInputList,
4301 TosaErrorValidator.evWrongOutputList,
4302 TosaErrorValidator.evWrongRank,
4303 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004304 },
4305 "scatter": {
4306 "op": Op.SCATTER,
4307 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004308 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 "operands": (2, 0),
4310 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 "build_fcn": (
4312 build_scatter,
4313 TosaTensorGen.tgScatter,
4314 TosaTensorValuesGen.tvgDefault,
4315 None,
4316 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004317 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004318 "error_if_validators": (
4319 TosaErrorValidator.evWrongInputType,
4320 TosaErrorValidator.evWrongOutputType,
4321 TosaErrorValidator.evWrongInputList,
4322 TosaErrorValidator.evWrongOutputList,
4323 TosaErrorValidator.evWrongRank,
4324 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004326 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004327 "resize": {
4328 "op": Op.RESIZE,
4329 "operands": (1, 0),
4330 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004331 "build_fcn": (
4332 build_resize,
4333 TosaTensorGen.tgNHWC,
4334 TosaTensorValuesGen.tvgDefault,
4335 TosaArgGen.agResize,
4336 ),
James Ward24dbc422022-10-19 12:20:31 +01004337 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004338 "invalid_test_validators": (
4339 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004340 ),
4341 "error_if_validators": (
4342 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004343 TosaErrorValidator.evScaleSmallerEqualZero,
4344 TosaErrorValidator.evScaleNLargerMax,
4345 TosaErrorValidator.evScaleDLargerMax,
4346 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004347 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004348 TosaErrorValidator.evBorderSmallerMin,
4349 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004350 TosaErrorValidator.evWrongInputType,
4351 TosaErrorValidator.evWrongOutputType,
4352 TosaErrorValidator.evWrongRank,
4353 TosaErrorValidator.evWrongInputList,
4354 TosaErrorValidator.evWrongOutputList,
4355 TosaErrorValidator.evBatchMismatch,
4356 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004357 TosaErrorValidator.evResizeOutputShapeMismatch,
4358 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004359 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004360 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004361 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004362 "cast": {
4363 "op": Op.CAST,
4364 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004365 "build_fcn": (
4366 build_cast,
4367 TosaTensorGen.tgBasic,
4368 TosaTensorValuesGen.tvgDefault,
4369 TosaArgGen.agCast,
4370 ),
James Ward8b390432022-08-12 20:48:56 +01004371 "types": (
4372 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004373 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004374 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004375 DType.INT8,
4376 DType.INT16,
4377 DType.INT32,
4378 DType.BOOL,
4379 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004380 "error_if_validators": (
4381 TosaErrorValidator.evWrongInputType,
4382 TosaErrorValidator.evWrongOutputType,
4383 TosaErrorValidator.evWrongInputList,
4384 TosaErrorValidator.evWrongOutputList,
4385 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004386 },
4387 "rescale": {
4388 "op": Op.RESCALE,
4389 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004390 "build_fcn": (
4391 build_rescale,
4392 TosaTensorGen.tgBasic,
4393 TosaTensorValuesGen.tvgDefault,
4394 TosaArgGen.agRescale,
4395 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004396 "types": [
4397 DType.UINT8,
4398 DType.INT8,
4399 DType.INT16,
4400 DType.INT32,
4401 DType.INT48,
4402 DType.UINT16,
4403 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004404 "error_if_validators": (
4405 TosaErrorValidator.evInputZeroPointNotZero,
4406 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004407 TosaErrorValidator.evU16InputZeroPointNotValid,
4408 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004409 TosaErrorValidator.evScaleTrue,
4410 TosaErrorValidator.evScaleNotTrue,
4411 TosaErrorValidator.evWrongInputType,
4412 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004413 TosaErrorValidator.evWrongInputList,
4414 TosaErrorValidator.evWrongOutputList,
4415 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004416 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004417 # Custom
4418 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004419 # Control flow operators
# Two variants of cond_if: one that generates one of two constant tensors (no
# inputs to the basic blocks, one output), and another that either adds or subtracts
# two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004423 "cond_if_const": {
4424 "op": Op.COND_IF,
4425 "operands": (0, 2),
4426 "build_fcn": (
4427 build_cond_if_const,
4428 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004429 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004430 TosaArgGen.agCondIf,
4431 ),
4432 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004433 "error_if_validators": (
4434 TosaErrorValidator.evOutputListThenGraphMismatch,
4435 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004436 TosaErrorValidator.evCondIfCondNotMatchingBool,
4437 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004438 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004439 },
4440 "cond_if_binary": {
4441 "op": Op.COND_IF,
4442 "operands": (2, 0),
4443 "build_fcn": (
4444 build_cond_if_binary,
4445 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004446 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 TosaArgGen.agCondIf,
4448 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004449 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004450 "error_if_validators": (
4451 TosaErrorValidator.evInputListThenGraphMismatch,
4452 TosaErrorValidator.evInputListElseGraphMismatch,
4453 TosaErrorValidator.evOutputListThenGraphMismatch,
4454 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004455 TosaErrorValidator.evCondIfCondNotMatchingBool,
4456 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004457 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004458 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004459 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 "while_loop": {
4461 "op": Op.WHILE_LOOP,
4462 "operands": (0, 1),
4463 "build_fcn": (
4464 build_while_loop,
4465 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004466 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467 TosaArgGen.agWhileLoop,
4468 ),
4469 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 "error_if_validators": (
4471 TosaErrorValidator.evInputListOutputListMismatch,
4472 TosaErrorValidator.evInputListCondGraphMismatch,
4473 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4474 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4475 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004476 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004477 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004478 },
Luke Hutton57287132023-02-06 14:54:18 +00004479 "fft2d": {
4480 "op": Op.FFT2D,
4481 "operands": (2, 0),
4482 "rank": (3, 3),
4483 "build_fcn": (
4484 build_fft2d,
4485 TosaTensorGen.tgFFT2d,
4486 TosaTensorValuesGen.tvgDefault,
4487 TosaArgGen.agFFT2d,
4488 ),
4489 "types": [DType.FP32],
4490 "error_if_validators": (
4491 TosaErrorValidator.evWrongInputType,
4492 TosaErrorValidator.evWrongOutputType,
4493 TosaErrorValidator.evWrongInputList,
4494 TosaErrorValidator.evWrongOutputList,
4495 TosaErrorValidator.evWrongRank,
4496 TosaErrorValidator.evBatchMismatch,
4497 TosaErrorValidator.evKernelNotPowerOfTwo,
4498 TosaErrorValidator.evFFTInputShapeMismatch,
4499 TosaErrorValidator.evFFTOutputShapeMismatch,
4500 ),
4501 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004502 "rfft2d": {
4503 "op": Op.RFFT2D,
4504 "operands": (1, 0),
4505 "rank": (3, 3),
4506 "build_fcn": (
4507 build_rfft2d,
4508 TosaTensorGen.tgRFFT2d,
4509 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004510 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004511 ),
4512 "types": [DType.FP32],
4513 "error_if_validators": (
4514 TosaErrorValidator.evWrongInputType,
4515 TosaErrorValidator.evWrongOutputType,
4516 TosaErrorValidator.evWrongInputList,
4517 TosaErrorValidator.evWrongOutputList,
4518 TosaErrorValidator.evWrongRank,
4519 TosaErrorValidator.evBatchMismatch,
4520 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004521 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004522 ),
4523 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004524 }
4525
Kevin Cheng550ccc52021-03-03 11:21:43 -08004526
Eric Kunzee5e26762020-10-13 16:11:07 -07004527class OutputShaper:
4528 # Methods in this class compute the expected output shape and datatype
4529 # for common classes of operations
def __init__(self):
    """Set up an OutputShaper; there is no per-instance state to initialise."""
4532
# These methods create the new output tensor via ser.addOutput() and
# return its result.
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting binary elementwise op.

    Each output dimension is taken from ``a``, except where ``a`` has size 1
    in that dimension and no error case is being generated; there the
    dimension from ``b`` is used instead.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` is called and its
            result returned.
        rng: random generator used to pick the fuzzed dimension and dtype.
        a: first input (needs ``shape`` and ``dtype``).
        b: second input (needs ``shape`` and ``dtype``).
        error_name: ``ErrorIf`` case to generate, or None for a valid test.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # b.shape is only indexed when error_name is None, where ranks match.
    shape = [
        b.shape[i] if (dim == 1 and error_name is None) else dim
        for i, dim in enumerate(a.shape)
    ]

    # rng.integers(0, 0) raises ValueError, so a rank-0 input used to crash
    # here. The draw is skipped only for rank 0, so the random stream (and
    # therefore every generated test) is unchanged for rank >= 1.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004568
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    Both inputs are asserted to share rank, dtype and every dimension; the
    output takes that shape and ``a``'s dtype.

    Returns:
        The result of ``ser.addOutput`` for the common shape and dtype.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    # zip is exhaustive here because the ranks were asserted equal above.
    for lhs_dim, rhs_dim in zip(a.shape, b.shape):
        assert lhs_dim == rhs_dim
        out_shape.append(lhs_dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004580
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for a unary elementwise op.

    The output keeps ``a``'s shape. Its dtype is ``a.dtype``, except for the
    ``ErrorIf.WrongOutputType`` case, where a different dtype is picked at
    random.

    Returns:
        The result of ``ser.addOutput`` for the chosen shape and dtype.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004599
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for a select-style op over ``cond``, ``a``, ``b``.

    Output dimensions follow ``cond``. Where ``cond`` has size 1 and no error
    case is being generated, the largest of the three input dimensions is
    used, which broadcasts ``cond`` against ``a`` and ``b``.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` is called and its
            result returned.
        rng: random generator used to pick the fuzzed dimension and dtype.
        cond: condition input (needs ``shape``).
        a: first value input (needs ``shape`` and ``dtype``).
        b: second value input (needs ``shape`` and ``dtype``).
        error_name: ``ErrorIf`` case to generate, or None for a valid test.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 input used to crash
    # here. The draw is skipped only for rank 0, so the random stream is
    # unchanged for rank >= 1.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004633
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor for a broadcasting comparison op.

    Each output dimension comes from ``a``, or from ``b`` where ``a`` has
    size 1 and ``b`` has that dimension. The dtype is ``DType.BOOL``, except
    for ``ErrorIf.WrongOutputType``, where a non-BOOL dtype is picked at
    random.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` is called and its
            result returned.
        rng: random generator used to pick the fuzzed dimension and dtype.
        a: first input (needs ``shape`` and ``dtype``).
        b: second input (needs ``shape`` and ``dtype``).
        error_name: ``ErrorIf`` case to generate, or None for a valid test.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast; the length check tolerates a shorter b under RankMismatch
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 input used to crash
    # here. The draw is skipped only for rank 0, so the random stream is
    # unchanged for rank >= 1.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004667
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction along ``axis``.

    The output copies ``a.shape`` with ``axis`` set to 1, unless an
    axis-related ``ErrorIf`` case is being generated. For
    ``ErrorIf.ShapeOfAxisNotOne``, an axis that is already 1 is replaced by a
    random size in [2, 10). The dtype matches ``a``, except for
    ``ErrorIf.WrongOutputType``.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    out_shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        out_shape[axis] = 1

    # Guarantee the reduced axis is not 1 when that is the error under test.
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004696
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for an argmax-style reduction.

    The output is ``a.shape`` with ``axis`` removed and dtype
    ``DType.INT32``. The ``ErrorIf`` cases handled here deliberately corrupt
    the rank, the dimensions, or the dtype.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    out_shape = a.shape.copy()

    # For the axis error cases the axis may be invalid, so keep every dim.
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Change the rank by one, only dropping a dim if another remains.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.INT32
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004730
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding`` is read
    as (top, bottom, left, right) pairs for H then W; ``strides`` and
    ``dilations`` are (H, W).

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` is called and its
            result returned.
        rng: random generator used for the shape/dtype error cases.
        ifm: input feature map (needs ``shape`` and ``dtype``).
        filter: weights (needs ``shape``).
        accum_dtype: dtype used for the output in the valid case.
        strides: (H, W) strides.
        padding: 4 padding amounts (H before/after, W before/after).
        dilations: (H, W) dilations.
        error_name: ``ErrorIf`` case to generate, or None for a valid test.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """

    def out_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Padded input extent minus the dilated kernel extent, then strided.
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # error case is not hit at the same time.
        if which in (1, 3):
            h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With an incorrect input type, pick a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004782
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add and return the output tensor of a CONV3D operator.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for ERROR_IF cases.
        ifm: input tensor (``.shape`` and ``.dtype`` are read).
        filter: weight tensor (``.shape`` is read).
        accum_dtype: accumulator dtype, used as the output dtype normally.
        strides: per-axis strides, depth/height/width order.
        padding: six pad amounts, one before/after pair per spatial axis.
        dilations: per-axis dilations, depth/height/width order.
        error_name: optional ErrorIf case to provoke.
    """
    # For each spatial axis:
    #   (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1
    spatial = [
        (
            ifm.shape[1 + axis]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[1 + axis] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change picks the axes to grow: 1 -> depth, 2 -> height,
        # 3 -> width, 4 -> all three. Growth is a whole multiple of the
        # stride so the non-integer error case is not triggered.
        for axis, selectors in enumerate(([1, 4], [2, 4], [3, 4])):
            if change in selectors:
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    out_d, out_h, out_w = spatial
    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With a deliberately wrong input type, use a potentially correct output
    # type so that only the input is in error.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input, presumably because either is
            # an accepted output type there.
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4844
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add and return the output tensor of a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for ERROR_IF cases.
        ifm: input tensor (``.shape`` and ``.dtype`` are read).
        filter: weight tensor (``.shape`` is read).
        accum_dtype: accumulator dtype, used as the output dtype normally.
        strides: per-axis strides, height first.
        padding: four pad amounts, the height pair then the width pair.
        dilations: per-axis dilations, height first.
        error_name: optional ErrorIf case to provoke.
    """
    # Filter is HWCM, so the kernel extent for axis k is filter.shape[k].
    spatial = [
        (
            ifm.shape[1 + axis]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change picks the axes to grow: 1 -> height, 2 -> width, 3 -> both.
        # Growth is a whole multiple of the stride so the non-integer error
        # case is not triggered by accident.
        for axis, selectors in enumerate(([1, 3], [2, 3])):
            if change in selectors:
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    out_h, out_w = spatial
    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # With a deliberately wrong input type, use a potentially correct output
    # type so that only the input is in error.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input, presumably because either is
            # an accepted output type there.
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004895
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add and return the output tensor of a 2D pooling operator.

    The input is NHWC; the output keeps the batch and channel sizes of the
    input and has the pooled height/width.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for ERROR_IF cases.
        ifm: input tensor (``.shape`` and ``.dtype`` are read).
        kernel: kernel size, height first.
        stride: stride, height first.
        pad: four pad amounts, the height pair then the width pair.
        error_name: optional ErrorIf case to provoke.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # An incorrect stride/pad makes the test invalid anyway; use 1x1.
        h, w = 1, 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole multiples of the stride to avoid the non-integer case.
        if change in [1, 3]:
            h += rng.choice(choices) * stride[0]
        if change in [2, 3]:
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004933
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor of a FULLY_CONNECTED operator.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).
    ``rng`` and ``error_name`` are not used here.

    Returns:
        The result of ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and deliberately invalidated there
    # for ErrorIf tests), so it is used unchanged as the output dtype.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004946
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor of a MATMUL operator.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for WrongOutputType.
        a: first input tensor (``.shape`` and ``.dtype`` are read).
        b: second input tensor (``.shape`` is read).
        accum_dtype: accumulator dtype (validated in arg_gen), used as the
            output dtype normally.
        error_name: optional ErrorIf case to provoke.

    Raises:
        ValueError: for ``ErrorIf.WrongOutputType`` when ``a.dtype`` has no
            list of incorrect output types. Previously this case surfaced as
            an UnboundLocalError.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Fail loudly instead of hitting an unbound local below.
            raise ValueError(
                f"matmulOp: no incorrect output types defined for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004994
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add and return the output tensor of a CONCAT operator.

    The output starts as a copy of the first input's shape. When possible,
    the other inputs' sizes along ``axis`` are added to it.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator (``integers`` / ``choice``) for ERROR_IF cases.
        axis: concatenation axis.
        inputs: list of input tensors; ``inputs[0]`` supplies the dtype.
        error_name: optional ErrorIf case to provoke.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    # Summing is impossible when ranks differ or the axis is invalid; in
    # those cases the first input's shape is used unchanged.
    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        output_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        output_dtype = first.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005030
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor of a PAD operator.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used for ERROR_IF cases.
        a: input tensor (``.shape`` and ``.dtype`` are read).
        padding: per-dimension (before, after) pad amounts.
        error_name: optional ErrorIf case to provoke.
    """
    output_shape = a.shape.copy()
    for idx, dim in enumerate(a.shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005061
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add and return the output tensor of a DIM operator.

    The output has an empty shape and dtype ``DType.SHAPE``. For
    ``ErrorIf.WrongOutputType`` a random non-SHAPE dtype is used instead.
    ``a`` and ``axis`` are not read here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    # DType.SHAPE is not among these candidates, so every entry is wrong.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([], rng.choice(list(set(candidates))))
5082
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor of a RESHAPE operator.

    The output shape is a copy of ``shape``, and the dtype follows ``a``,
    unless an ERROR_IF case alters either.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator (``integers`` / ``choice``) for ERROR_IF cases.
        a: input tensor (``.dtype`` is read).
        shape: requested output shape.
        error_name: optional ErrorIf case to provoke.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension by 1..9, presumably so the element counts no
        # longer match.
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005107
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add and return the output tensor of a SLICE operator.

    The output shape is a copy of ``size`` and the dtype follows ``input``,
    unless an ERROR_IF case alters either. ``start`` is not read here.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for ERROR_IF cases.
        input: input tensor (``.shape`` and ``.dtype`` are read).
        start: slice start (unused).
        size: slice size, copied as the output shape.
        error_name: optional ErrorIf case to provoke.
    """
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(all_dtypes) - {input.dtype}))
    else:
        outputDType = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        # Small extents only grow (keeping them positive); larger ones may
        # grow or shrink.
        for idx, extent in enumerate(output_shape):
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005141
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add and return the output tensor of a TILE operator.

    Each output dimension is ``a.shape[i] * multiples[i]``.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used for ERROR_IF cases.
        a: input tensor (``.shape`` and ``.dtype`` are read).
        multiples: per-dimension repeat counts; must match the rank of ``a``.
        error_name: optional ErrorIf case to provoke.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, (dim, factor) in enumerate(zip(a.shape, multiples)):
        output_shape[idx] = dim * factor

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005170
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add and return the output tensor of a TRANSPOSE operator.

    Output dimension ``i`` is ``a.shape[perms[i]]``.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used for ERROR_IF cases.
        a: input tensor (``.shape`` and ``.dtype`` are read).
        perms: permutation; must have the same length as ``a.shape``.
        error_name: optional ErrorIf case to provoke.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    # Invalid permutations cannot be applied, so the input shape is kept.
    if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        for idx, source_dim in enumerate(perms):
            output_shape[idx] = a.shape[source_dim]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005203
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add and return the output tensor of a GATHER operator.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``
    and the dtype follows ``values``. Rank and batch sanity asserts are
    skipped for ``ErrorIf.WrongRank`` tests.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005229
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add and return the output tensor of a SCATTER operator.

    The output takes ``values_in``'s shape and, normally, its dtype. Rank
    and dimension sanity asserts are skipped for ``ErrorIf.WrongRank`` tests.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    # values_in.shape is passed as-is (the same object, not a copy).
    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005258
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add and return the output tensor of a TABLE operator.

    The output has the input's shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise, unless ``ErrorIf.WrongOutputType`` selects a random
    different dtype.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    expected = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
    output_dtype = expected

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([dt for dt in candidates if dt != expected])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005278
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add and return the output tensor of a RESIZE operator.

    OH/OW are computed as
    ``((in - 1) * scale_n - offset + border) // scale_d + 1`` per axis.
    The result is then adjusted for ERROR_IF cases. ``mode`` and
    ``input_dtype`` are not read here.

    Args:
        serializer: serializer; the result of ``serializer.addOutput`` is
            returned.
        rng: random generator (``choice`` / ``integers``) for ERROR_IF cases.
        input: input tensor (``.shape`` is read).
        scale: [scale_y_n, scale_y_d, scale_x_n, scale_x_d].
        offset: [offset_y, offset_x].
        border: [border_y, border_x].
        output_dtype: dtype of the created output tensor.
        error_name: optional ErrorIf case to provoke.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the shape arithmetic below well defined.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(value, 1) for value in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # scale, offset or border may have been changed for ERROR_IFs; keep
        # the output tensor itself valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    def _step_within_limit(size, step):
        # Move by one scale denominator (avoiding the non-integer error case),
        # backing off instead when growing would reach the maximum dimension.
        if size + step >= gtu.MAX_RESIZE_DIMENSION:
            size -= step
            assert size > 0  # Should have been caught in agResize
        else:
            size += step
        return size

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> height, 2 -> width, 3 -> both.
        change = rng.choice([1, 2, 3])
        if change in [1, 3]:
            oh = _step_within_limit(oh, scale_y_d)
        if change in [2, 3]:
            ow = _step_within_limit(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # The last dim reuses the batch size, presumably because
        # input.shape[3] may not exist for a wrong-rank input.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005357
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add and return an output tensor with ``val``'s shape and ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not used here.

    Returns:
        The result of ``ser.addOutput``.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005361
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add and return the output tensor of a TRANSPOSE_CONV2D operator.

    The output shape is supplied by the caller. For
    ``ErrorIf.ConvOutputShapeMismatch``, dims 1 and/or 2 are grown by 1..3.

    NOTE(review): that adjustment modifies the caller's ``output_shape``
    list in place; confirm whether callers rely on seeing the change.

    Args:
        ser: serializer; the result of ``ser.addOutput`` is returned.
        rng: random generator, used via ``rng.choice`` for ERROR_IF cases.
        ifm: input tensor (``.dtype`` is read).
        output_shape: output shape list (may be mutated, see above).
        accum_dtype: accumulator dtype, used as the output dtype normally.
        error_name: optional ErrorIf case to provoke.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change picks the dims to grow: 1 -> dim 1, 2 -> dim 2, 3 -> both.
        for dim, selectors in ((1, [1, 3]), (2, [2, 3])):
            if change in selectors:
                output_shape[dim] = output_shape[dim] + rng.choice(choices)

    # With a deliberately wrong input type, use a potentially correct output
    # type so that only the input is in error.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both are excluded for FP16 input, presumably because either is
            # an accepted output type there.
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005387
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add and return the two output tensors of an FFT2D operator.

    Both outputs are created from the same shape list (a copy of
    ``ifm1.shape``) and the same dtype (``ifm1.dtype``), unless an ERROR_IF
    case alters either.

    Returns:
        A list with the two ``serializer.addOutput`` results.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = input_shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5418
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add and return the two output tensors of an RFFT2D operator.

    The output shape keeps all but the last input dimension; the last one
    becomes ``last // 2 + 1``. The dtype follows ``value``. Both outputs
    share one shape list and dtype, unless an ERROR_IF case alters them.

    Returns:
        A list with the two ``serializer.addOutput`` results.
    """
    input_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    output_shape = list(input_shape[:-1]) + [input_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]