blob: 1995cbc56996cd3ed80f17ccefdf809b84bf17f4 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner text for auto-generated C++ files; the copyright year is taken
# from the current date when this module is loaded.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8 * 1024
TOSA_8K_LEVEL_MAX_STRIDE = 8 * 1024

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = range(6)
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Set up the generator state from the parsed program arguments.

    Args:
        args: program options; this method reads ``output_dir``,
            ``random_seed``, ``generate_lib_path`` and
            ``tensor_fp_value_range`` from it and keeps the whole object
            as ``self.args``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Serializer is created per test by createSerializer()
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        """Return rangeFP sorted, with "max"/"-max" replaced by +/-maxFP."""
        resolved = (
            maxFP if v == "max" else (-maxFP if v == "-max" else v)
            for v in rangeFP
        )
        return tuple(sorted(resolved))

    # Work out floating point range for each floating point type
    self.random_float_range = {
        dtype: convertFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
        )
        for dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
78
def createSerializer(self, opName, testPath):
    """Create the test output directory and a new serializer for it.

    The directory is ``basePath/opName/testPath``; the relative part is
    remembered in ``self.testPath`` and the serializer in ``self.ser``.

    Args:
        opName: first path component under the base output path.
        testPath: path of the test below ``opName``.
    """
    self.testPath = os.path.join(opName, testPath)

    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(out_dir, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    return self.ser
95
def serialize(self, testName, metaData=None):
    """Write the current test to its output directory.

    Produces ``<testName>.tosa`` (flatbuffer binary), ``desc.json`` and,
    depending on ``metaData``, C++ files holding JSON configuration as
    raw string constants.

    Args:
        testName: base name of the flatbuffer and C++ files.
        metaData: optional dict added to desc.json under ``"meta"``. Its
            ``"data_gen"`` entry is written as C++ only when lazy data
            generation is enabled; its ``"compliance"`` entry is written
            as C++ whenever it is present.
    """
    path = Path(self.basePath) / self.testPath

    def write_cpp_meta(path_md, title, declaration, payload):
        # One JSON payload wrapped in a C++ raw string literal
        with path_md.open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(title)
            fd.write(declaration)
            json.dump(payload, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    with (path / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))

    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output datagen meta data as CPP data
            write_cpp_meta(
                path / f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{path.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            write_cpp_meta(
                path / f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{path.stem} = R"(',
                metaData["compliance"],
            )

    # Write desc.json
    with (path / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded generator.

    Args:
        seed: seed for the new generator; when None, ``random_seed + 1``
            is used instead.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the value range boundaries (low, high) of a data type.

    For boolean and integer types the high boundary is excluded from the
    range unless ``high_inclusive`` is True. Floating point types return
    their ``random_float_range`` entry as is, whatever ``high_inclusive``
    is set to.

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # Each entry: (matching dtypes, (low, exclusive high))
    bounded_types = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # restricting too large value for SHAPE
        ((DType.INT32, DType.SHAPE), (-(1 << 31), (1 << 31))),
        ((DType.INT48,), (-(1 << 47), (1 << 47))),
    )
    for candidates, (low, high) in bounded_types:
        if dtype in candidates:
            # Exclusive high: low <= range < high
            # Inclusive range: low <= range <= high
            return (low, high - 1) if high_inclusive else (low, high)

    raise Exception("Unknown dtype: {}".format(dtype))
178
def getRandTensor(self, shape, dtype):
    """Return a numpy array of random values with the given shape.

    The value range comes from getDTypeRange(). The numpy element type is
    bool for BOOL, int64 for INT48, float16 for FP16, float32 for FP32 and
    BF16 (after conversion by gtu.vect_f32_to_bf16) and int32 otherwise.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    values = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(values)

    values_f32 = np.float32(values)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(values_f32))
    return values_f32
Eric Kunzee5e26762020-10-13 16:11:07 -0700201
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder tensor to the serializer per shape/dtype pair.

    Random data is generated for each tensor unless lazy data generation
    is enabled, in which case None is passed as the data.

    Args:
        shape_list: shapes of the tensors to create.
        dtype_list: data types, one per entry of ``shape_list``.

    Returns:
        List of the values returned by ``self.ser.addPlaceholder``.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
214
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant tensor to the serializer per shape/dtype pair.

    Random data is generated for each tensor unless lazy data generation
    is enabled, in which case None is passed as the data.

    Args:
        shape_list: shapes of the tensors to create.
        dtype_list: data types, one per entry of ``shape_list``.

    Returns:
        List of the values returned by ``self.ser.addConst``.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
227
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    When a target shape has been set (and is truthy) it is returned
    regardless of ``rank``; otherwise ``rank`` random dimensions are drawn
    from ``args.tensor_shape_range`` (upper bound excluded).
    """
    forced_shape = self.targetted_shape
    if forced_shape:
        return np.int32(forced_shape)

    shape_range = self.args.tensor_shape_range
    dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
def setTargetShape(self, shape):
    """Store the shape that makeShape() should return instead of a random one."""
    self.targetted_shape = shape
241
def randInt(self, low=0, high=256):
    """Return one random np.int32 value with low <= value < high."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
244
def getRandNumberDType(self, dtype):
    """Return a single random value for the given data type.

    The value range comes from getDTypeRange(). FP32 and FP16 values are
    numpy scalars of that width, BF16 is a float32 passed through
    gtu.vect_f32_to_bf16, BOOL is picked from [False, True], INT48 is an
    np.int64 and every other type is an np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample_f32 = np.float32(sample)
        return sample_f32 if dtype == DType.FP32 else gtu.vect_f32_to_bf16(sample_f32)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 is the special size that needs 64 bits
    int_type = np.int64 if dtype == DType.INT48 else np.int32
    return int_type(self.rng.integers(low, high, size=1))[0]
262
def shapeStr(self, shape):
    """Return the dimensions of ``shape`` joined by "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271
def typeStr(self, dtype):
    """Return the string name of a data type, or of a list/tuple of types.

    For a list or tuple (at least 2 entries) every entry is converted, but
    only the first two names are joined with "x".

    Raises:
        Exception: if a single dtype is not in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Raises:
        Exception: if ``dtype`` is not in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
def constrictBatchSize(self, shape):
    """Clamp the first dimension of ``shape`` to ``args.max_batch_size``.

    Nothing is changed when no maximum batch size is set or when explicit
    target shapes were requested. ``shape`` is modified in place and also
    returned.
    """
    if not self.args.max_batch_size or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
298
def makeDimension(self):
    """Return one random dimension drawn by randInt from args.tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
303
def tensorComplianceMetaData(self, op, inputType, argsDict, outputTensor, errorName):
    """Create the compliance meta data for an expected output tensor.

    Args:
        op: operator dictionary; "op" and the optional "compliance" entry
            are read.
        inputType: input data type, checked for compliance support.
        argsDict: argument dictionary; "dg_type" is always read, plus "s"
            and "ksb" (or "ks" when "ksb" is absent) for dot product data.
        outputTensor: result tensor; only its ``dtype`` is used.
        errorName: name of the error test being built, or a false value.

    Returns:
        None for error tests and for types rejected by
        gtu.dtypeIsSupportedByCompliance; otherwise a dict with "mode",
        "data_type" and, depending on the mode, "dot_product_info" or
        "ulp_info".
    """
    if errorName or not (
        gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        and gtu.dtypeIsSupportedByCompliance(inputType)
    ):
        # No compliance for error tests or unsupported types currently
        return None

    # "mode" is created first and filled in last so it stays the first key
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        ks_key = "ksb" if "ksb" in argsDict else "ks"
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ks_key]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    else:
        mode = gtu.ComplianceMode.EXACT
    compliance_tens["mode"] = gtu.ComplianceMode(mode).name

    return compliance_tens
341
342 # Build Op functions
343 # Create the output tensor (calling OutputShaper as needed)
344 # Do final tweaks to attributes (if necessary for errorIf)
345 # Add Op into graph
346 # Return resulting tensor information or BuildInfo
347
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict.

    Attributes:
        resultTensor: the result tensor given to the constructor.
        complianceDict: the compliance dictionary given to the constructor.
    """

    def __init__(self, resultTensor, complianceDict):
        # Both values are stored as given, without copying
        self.resultTensor = resultTensor
        self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700354
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator to the graph and return its result tensor.

    Args:
        op: operator dictionary ("op" and "operands" are read) or a bare
            integer operator code.
        a: input tensor.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.
        qinfo: zero point pair; read as qinfo[0] and qinfo[1] for NEGATE.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # A bare integer operator code and IDENTITY are added directly and
    # skip the validation below
    if isinstance(op, int) or op["op"] == Op.IDENTITY:
        opcode = op if isinstance(op, int) else op["op"]
        self.ser.addOperator(opcode, a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and result_tens.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
405
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input broadcasting operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
438
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input non-broadcasting operator to the graph.

    ``validator_fcns`` and ``error_name`` are accepted but not used here.

    Returns:
        The result tensor created by OutputShaper.binaryNonBroadcastOp.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [output.name])
    return output
443
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        round: value stored in the ArithmeticRightShiftAttribute.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, shift_attr)
    return result_tens
481
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        inputs: exactly two input tensors.
        args_dict: argument dictionary; "shift" is read here and the whole
            dictionary is passed to tensorComplianceMetaData.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.
        qinfo: not used by this method.

    Returns:
        A TosaTestGen.BuildInfo holding the result tensor and its
        compliance meta data, or None when evValidateErrorIfs returns a
        false value.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply: Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Pick a random output type from this list for the error test
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(self.rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validator_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, mul_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700537
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: input tensor.
        table: value stored in the TableAttribute.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, table_attr)

    return result_tens
571
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        cond: condition tensor (first operator input).
        a: second operator input tensor.
        b: third operator input tensor.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
608
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input comparison operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
647
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: input tensor.
        axis: value stored in the AxisAttribute.
        validator_fcns: passed on to TosaErrorValidator.evValidateErrorIfs.
        error_name: error being tested, or None.

    Returns:
        The result tensor, or None when evValidateErrorIfs returns a
        false value.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
682
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator for the single tensor in `inputs`.

    `args_dict` must provide "stride", "pad" and "kernel"; "acc_type" is
    optional and falls back to DType.UNKNOWN (max_pool has no accum_dtype).
    Returns the output tensor, or None when validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and ifm.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Operand name lists may be altered here for error-if tests.
    in_names = [ifm.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    # Without quantization info both zero points are serialized as 0.
    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)
    return out_tensor
751
def build_maxpool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a pooling operator via build_pool2d and attach compliance data.

    All arguments are forwarded unchanged to build_pool2d.

    Returns:
        A TosaTestGen.BuildInfo wrapping the output tensor and the result of
        tensorComplianceMetaData, or None when build_pool2d returned None
        (i.e. validation rejected the test).
    """
    result_tensor = self.build_pool2d(
        op,
        inputs,
        args_dict,
        validator_fcns,
        error_name,
        qinfo,
    )
    if result_tensor is None:
        # Propagate the rejection: wrapping None in a BuildInfo would make an
        # invalid test look like a successfully built one to the caller.
        return None

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100774
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a convolution operator for (ifm, filter, bias) with 4-entry padding.

    Args:
        op: operator description; reads op["op"] and op["operands"].
        inputs: exactly three tensors: ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        A TosaTestGen.BuildInfo with the output tensor and compliance data,
        or None when validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700853
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a convolution operator for (ifm, filter, bias) with 6-entry padding.

    Args:
        op: operator description; reads op["op"] and op["operands"].
        inputs: exactly three tensors: ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        The output tensor, or None when validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
927
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a transpose-convolution operator for ifm, filter and bias.

    Args:
        op: operator description; reads op["op"] and op["operands"].
        ifm, filter, bias: input tensors, serialized in that order.
        accum_dtype: passed to OutputShaper.transposeConv2DOp.
        stride: stride attribute value.
        out_pad: output padding; must have exactly 4 entries.
        output_shape: requested output shape, also stored in the attribute.
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        The output tensor, or None when validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
990
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise convolution operator for (ifm, filter, bias).

    Args:
        op: operator description; reads op["op"] and op["operands"].
        inputs: exactly three tensors: ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        The output tensor, or None when validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1063
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator for ifm, filter and bias.

    Args:
        op: operator description; reads op["op"] and op["operands"].
        ifm, filter, bias: input tensors, serialized in that order.
        accum_dtype: passed to OutputShaper.fullyConnectedOp and validators.
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        The output tensor, or None when validation rejects the test.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1112
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matrix-multiply operator for the two tensors in `inputs`.

    Args:
        op: operator description; reads op["op"] and op["operands"].
        inputs: exactly two tensors, a and b.
        args_dict: must provide "acc_type".
        validator_fcns: validators handed to evValidateErrorIfs.
        error_name: error-if case being generated, or None.
        qinfo: two zero-point values; None is serialized as [0, 0].

    Returns:
        A TosaTestGen.BuildInfo with the output tensor and compliance data,
        or None when validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None; subscripting it would raise TypeError, so use
    # zero for both zero points (same fallback as build_pool2d).
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001162
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Add a reduction operator over `a` along `axis`.

    The output tensor comes from OutputShaper.reduceOp. Returns that tensor,
    or None when TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Operand name lists may be altered here for error-if tests.
    in_names = [a.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return out_tensor
1197
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a clamp operator on `a` with randomly drawn min/max limits.

    For the MaxSmallerMin error case the two limits are deliberately swapped.
    Returns the output tensor, or None when validation rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_limits():
        # Two independent draws, in the same order on every call.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    limits = draw_limits()

    if error_name == ErrorIf.MaxSmallerMin:
        # The limits must differ, otherwise swapping them triggers nothing.
        while limits[0] == limits[1]:
            limits = draw_limits()
        max_val, min_val = min(limits), max(limits)
    else:
        max_val, min_val = max(limits), min(limits)

    # Operand name lists may be altered here for error-if tests.
    in_names = [a.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    is_float_input = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float_input:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return out_tensor
1253
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-relu operator on `a` with a random FP32 attribute value.

    No error-if validation is performed here; the output tensor is returned.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    rand_fp32 = self.getRandNumberDType(DType.FP32)
    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(rand_fp32)

    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names, relu_attr)
    return out_tensor
1262
1263 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a single-input operator on `a` without attributes or validation.

    Returns the output tensor produced by OutputShaper.unaryOp.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tensor
1269
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a unary operator on `a` with no attribute.

    Returns the output tensor from OutputShaper.unaryOp, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Operand name lists may be altered here for error-if tests.
    in_names = [a.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1300
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a unary operator on `a` with no attribute.

    Returns the output tensor from OutputShaper.unaryOp, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Operand name lists may be altered here for error-if tests.
    in_names = [a.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1331
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Add a unary operator on `a` with no attribute.

    Returns the output tensor from OutputShaper.unaryOp, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Operand name lists may be altered here for error-if tests.
    in_names = [a.name]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1362
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a concat operator; `a` holds the input tensors followed by the axis.

    Returns the output tensor from OutputShaper.concatOp, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # The axis travels as the last positional argument so that a variable
    # number of input tensors can precede it.
    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    # Operand name lists may be altered here for error-if tests.
    in_names = [tensor.name for tensor in tensors]
    out_names = [out_tensor.name]
    count_a, count_b = op["operands"]
    total_operands = count_a + count_b
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=out_tensor.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=total_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001411
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator to the serializer and return its build info.

    ``inputs`` must contain exactly one tensor. The padding and the integer
    and floating point pad constants are read from ``args_dict`` under the
    keys "pad", "pad_const_int" and "pad_const_fp".

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance meta data, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, padding.flatten(), const_int, const_fp)

    input_names = [src.name]
    output_names = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": src.shape,
        "output_shape": result_tensor.shape,
        "input_dtype": src.dtype,
        "output_dtype": result_tensor.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [result_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": src,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001469
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator described by ``op`` with an axis attribute.

    The result tensor comes from ``OutputShaper.dimOp``. ``qinfo`` is
    accepted but not used by this builder.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": result_tens.shape,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1512
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a reshape operator carrying ``newShape`` as its attribute.

    Returns the result tensor from ``OutputShaper.reshapeOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1548
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reverse operator with an axis attribute to the serializer.

    The result tensor is created with ``OutputShaper.unaryOp``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1583
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a transpose operator carrying ``perms`` as its attribute.

    Returns the result tensor from ``OutputShaper.transposeOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "perms": perms,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": a,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1619
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a slice operator with ``start`` and ``size`` as its attribute.

    Returns the result tensor from ``OutputShaper.sliceOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "start": start,
        "size": size,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": a,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1658
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a tile operator carrying ``multiples`` as its attribute.

    Returns the result tensor from ``OutputShaper.tileOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
        "input1": a,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1693
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a gather operator using a freshly generated INT32 indices constant.

    The indices are drawn from ``self.rng`` with ``high=values.shape[1]`` so
    that they do not exceed the dimensions of the values tensor.

    Returns the result tensor from ``OutputShaper.gatherOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values tensor.
    k_dim = values.shape[1]  # K
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    random_indices = self.rng.integers(
        low=0, high=k_dim, size=[values.shape[0], w_dim]
    )
    indices_arr = np.int32(random_indices)  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    input_names = [values.name, indices.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": values.shape,
        "output_shape": result_tens.shape,
        "input_dtype": values.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1740
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a scatter operator using a freshly generated INT32 indices constant.

    The indices are drawn from ``self.rng`` with ``high=values_in.shape[1]``
    so that they do not exceed the dimensions of the values_in tensor; the
    second dimension of the indices is taken from ``input.shape[1]``.

    Returns the result tensor from ``OutputShaper.scatterOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values_in tensor.
    k_dim = values_in.shape[1]  # K
    w_dim = input.shape[1]  # W
    random_indices = self.rng.integers(
        low=0, high=k_dim, size=[values_in.shape[0], w_dim]
    )
    indices_arr = np.int32(random_indices)  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    input_names = [values_in.name, indices.name, input.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": values_in.shape,
        "output_shape": result_tens.shape,
        "input_dtype": values_in.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1785
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a resize operator with scale, offset, border and mode attributes.

    Returns the result tensor from ``OutputShaper.resizeOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    shaper_args = (
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )
    result_tens = OutputShaper.resizeOp(*shaper_args)

    input_names = [input.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": result_tens.shape,
        "offset": offset,
        "border": border,
        "input_list": input_list,
        "output_list": output_list,
        "result_tensors": [result_tens],
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1847
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add a two-input, two-output identity operator to the serializer.

    A result tensor is created for each input with ``OutputShaper.unaryOp``.
    ``validator_fcns`` is accepted but not used by this builder.

    Returns the result tensor created for ``val`` (the one created for
    ``val2`` is registered as an operator output but not returned).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the operator code stored under op["op"], as the other builders do,
    # rather than the whole operator definition entry.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1855
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted but not used
    by this builder.
    """
    self.ser.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001859
1860 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add a cast operator converting ``val`` to ``out_dtype``.

    Returns the result tensor from ``OutputShaper.typeConversionOp``, or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    input_names = [val.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validation_args = {
        "op": op,
        "input_shape": val.shape,
        "output_shape": result_tens.shape,
        "input_dtype": val.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": num_operands,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1893
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a rescale operator converting ``val`` to ``out_dtype``.

    Picks input/output zero points, generates random scale values (one per
    channel of the last dimension when ``per_channel`` is set, otherwise a
    single value), converts them to multiplier/shift pairs with
    ``TosaQuantGen.computeMultiplierAndShift`` and stores everything in a
    rescale attribute.

    For ``scale32`` tests without an ``error_name`` the placeholder data
    file of ``val`` is loaded, clamped to the range allowed by the computed
    shifts, and rewritten if any value changed.

    Returns the result tensor from ``OutputShaper.typeConversionOp``, or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a false value.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of scale values: one per entry of the last dimension for
    # per-channel rescale, otherwise a single value.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Choose the input zero point. Every branch that may use a non-zero zero
    # point also widens the type width by one bit.
    # NOTE(review): presumably the extra bit covers the range growth from
    # subtracting the zero point - confirm.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # These error tests need a non-zero zero point, so nudge a zero draw
        # away from zero.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Choose the output zero point, mirroring the input handling above.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # These error tests need a non-zero zero point, so nudge a zero draw
        # away from zero.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        pass
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    # Convert each scale into a multiplier/shift pair and record the value
    # range [-(1 << (shift - 1)), (1 << (shift - 1)) - 1] for that shift.
    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        # Clamp in the zero-point-adjusted domain using int64 intermediates,
        # then shift back and return to the dtype of the stored data.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            # (the third positional argument of np.save is allow_pickle)
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2048
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor holding ``cond``.

    The tensor is normally a rank 0 ``DType.BOOL`` constant. For
    ``ErrorIf.CondIfCondNotMatchingBool`` the dtype comes from
    ``gtu.get_wrong_output_type`` instead, and for
    ``ErrorIf.CondIfCondShapeNotSizeOne`` the shape is randomly chosen to be
    ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2065
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN and ELSE blocks each output one constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block is filled with an INT32 constant of random values in [0, 256).

    For ``ErrorIf.CondIfOutputListThenGraphMismatch`` and
    ``ErrorIf.CondIfOutputListElseGraphMismatch`` the matching block instead
    gets a constant with a deliberately different shape.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid.
    """
    # Condition tensor
    condition = self._get_condition_tensor(op, cond, error_name)

    shape = then_tens.shape
    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    # Create an incorrect output shape (and data for it) for error_if tests
    if error_name in (then_mismatch, else_mismatch):
        bad_shape = deepcopy(then_tens.shape)
        for axis, dim in enumerate(bad_shape):
            # Larger dimensions may shrink or grow; small ones only grow
            offsets = [-3, -2, 2, 3] if dim > 3 else [1, 2, 4]
            bad_shape[axis] = dim + self.rng.choice(offsets)
        bad_values = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_values = np.int32(self.rng.integers(0, 256, size=shape))
    else_values = np.int32(self.rng.integers(0, 256, size=shape))

    # The result tensor has the (correct) shape of the block outputs
    result_tens = self.ser.addOutput(shape, DType.INT32)

    # The attribute carries the names of the then/else blocks
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute("THEN_BLOCK", "ELSE_BLOCK")

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [condition.name], [result_tens.name], attr)

    blocks = (
        ("THEN_BLOCK", then_mismatch, then_values),
        ("ELSE_BLOCK", else_mismatch, else_values),
    )
    for block_name, mismatch, values in blocks:
        self.ser.addBasicBlock(block_name)
        # Build the actual constant inside its block
        if error_name == mismatch:
            block_const = self.ser.addConst(bad_shape, DType.INT32, bad_values)
        else:
            block_const = self.ser.addConst(shape, DType.INT32, values)
        self.ser.addOutputTensor(block_const)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=condition,
    )
    if not is_valid:
        return None

    return result_tens
2134
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN and ELSE blocks apply a binary op to a, b.

    The blocks use ADD/SUB when ``a.dtype`` is FP32, BF16, FP16 or INT32 and
    LOGICAL_RIGHT_SHIFT/LOGICAL_LEFT_SHIFT when it is INT8 or INT16.

    For the ``CondIfInputList*GraphMismatch`` and
    ``CondIfOutputList*GraphMismatch`` ERROR_IF cases, the matching block is
    given an input (or an output) with a deliberately wrong shape.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid.

    Raises:
        AssertionError: if ``a.dtype`` is not one of the data types above.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Copy of `a` with every dimension perturbed, used as the bad tensor
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise (rather than `assert False`) so the check survives
        # running Python with -O
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # NOTE: the loop variable must not be called `op`; doing so would rebind
    # the `op` parameter and hand the validators the ELSE block opcode
    # instead of the operator definition.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)

        in_then = block == then_block
        in_else = block == else_block
        bad_input = (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch and in_then
        ) or (error_name == ErrorIf.CondIfInputListElseGraphMismatch and in_else)
        bad_output = (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch and in_then
        ) or (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and in_else)

        if bad_input:
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif bad_output:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2217
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP with a COND block and a BODY block.

    The loop carries three tensors: an INT32 scalar counter initialised to
    ``iter_val``, the input ``a`` and an accumulator initialised to zeros
    with the shape of ``a``. The COND block adds ``GREATER(counter, 0)``.
    The BODY block adds ``ADD(a, acc)`` and ``SUB(counter, 1)``.

    The ``InputList*`` and ``CondGraphOutput*`` ERROR_IF cases substitute
    wrongly shaped tensors, or a wrong condition type/shape, at the
    matching point in the graph.

    Returns the accumulator output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid.
    """
    shape_offsets = [-3, -2, 2, 3]

    def perturbed_copy(tensor):
        # Deep copy so the real tensor keeps its shape, then shift every
        # dimension of the copy by a randomly chosen offset
        clone = deepcopy(tensor)
        for axis in range(len(clone.shape)):
            clone.shape[axis] += self.rng.choice(shape_offsets)
        return clone

    def add_block_inputs(mismatch_error):
        # Register the three loop-carried tensors as inputs of the current
        # block, using the bad counter/accumulator for the matching error
        if error_name == mismatch_error:
            inputs = (bad_counter, a, bad_acc)
        else:
            inputs = (counter, a, acc)
        for tensor in inputs:
            self.ser.addInputTensor(tensor)

    counter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [counter.name, a.name, acc.name],
        [counter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        bad_counter = perturbed_copy(counter)
        if not bad_counter.shape:
            # A rank 0 counter has no dimension to perturb, so add one
            bad_counter.shape.append(self.rng.choice(shape_offsets))

        bad_acc = perturbed_copy(acc)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)
    add_block_inputs(ErrorIf.InputListCondGraphMismatch)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL

    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [counter.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    add_block_inputs(ErrorIf.InputListBodyGraphInputMismatch)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        counter_like, acc_like = bad_counter, bad_acc
    else:
        counter_like, acc_like = counter, acc
    counter_body_out = self.ser.addIntermediate(
        counter_like.shape, counter_like.dtype
    )
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [counter.name, one_tens.name], [counter_body_out.name]
    )
    for tensor in (counter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not is_valid:
        return None

    return acc_out
2338
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D operator taking ``val1`` and ``val2`` as its inputs.

    Output tensors come from ``OutputShaper.fft2dOp``. For ERROR_IF tests
    the input/output name lists may be altered by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``.

    Returns the list of result tensors, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    operand_names = [val1.name, val2.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    result_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=operand_names,
        output_list=result_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], operand_names, result_names, attr)
    return results
2380
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D operator taking ``val`` as its single input.

    Output tensors come from ``OutputShaper.rfft2dOp``. For ERROR_IF tests
    the input/output name lists may be altered by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``.

    Returns the list of result tensors, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` reports the test as invalid.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    operand_names = [val.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    result_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=operand_names,
        output_list=result_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], operand_names, result_names)
    return results
2414
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator supports.

    Args:
        op: operator definition; ``op["rank"]`` is a (min, max) pair and
            ``op["types"]`` lists the supported data types.
        shapeFilter: requested shapes; an empty/None value becomes ``[None]``.
        rankFilter: requested ranks, or None for no request.
        dtypeFilter: requested data types, or None for no request.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator, needed for negative tests. It is
            called as ``validator(check=False, op=op)`` and its
            ``"param_reqs"`` entry may override rank, dtype and shape.

    Returns:
        A dict with ``"shapeFilter"``, ``"rankFilter"`` and ``"dtypeFilter"``
        entries. None is returned for a negative test without a validator
        and for any other ``testType``.
    """
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    lowest_rank, highest_rank = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        ranks = [
            rank for rank in rankFilter if lowest_rank <= rank <= highest_rank
        ]
    elif shapeFilter[0] is None:
        # Nothing was requested: use the operator's range unless it is larger
        # than the default range (ranks 1-4), which keeps test sizes small
        default_ranks = range(1, 5)
        op_ranks = range(lowest_rank, highest_rank + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(lowest_rank, highest_rank + 1)

    supported_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = supported_dtypes
    else:
        # Operator dtypes filtered by the requested dtypes; a list-valued
        # dtype is matched on its first element
        dtypes = [
            dtype
            for dtype in supported_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    requirements = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required by the validator, otherwise fall back to
    # the filters calculated above
    negative_ranks = requirements["rank"]
    if negative_ranks is None:
        negative_ranks = ranks

    negative_dtypes = requirements["dtype"]
    if negative_dtypes is None:
        negative_dtypes = dtypes

    negative_shapes = requirements["shape"]
    if negative_shapes is None:
        # Reduce number of shapes to keep test numbers small
        negative_shapes = shapeFilter[:2]

    return {
        "shapeFilter": negative_shapes,
        "rankFilter": negative_ranks,
        "dtypeFilter": negative_dtypes,
    }
2495
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to build for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: list of shapes to test; None (the default) is treated
            as ``[None]``, i.e. no particular shape is requested.
        rankFilter: list of ranks to test, or None.
        dtypeFilter: list of data types to test, or None.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF tests).

    Returns:
        A list of tuples
        ``(opName, testStr, dtype, error_name, shapeList, args)``.
        The list is empty when ``create_filter_lists`` yields no filters.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # A None sentinel replaces the former mutable `[None]` default so that
    # no list object is shared between calls
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # A single pass with no validator (used for positive tests)
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (str, args) tuples: the string
                    # representation and the build function arguments
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        else:
                            # Only "negative" gets here: create_filter_lists
                            # returns None for any other test type
                            testStr = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        clean_testList = []
        for test in testList:
            remove_test = False
            # Every validator is called, exactly as before (no short-circuit)
            for validator_fcn in invalid_test_validators:
                if validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                ):
                    remove_test = True
            if not remove_test:
                clean_testList.append(test)
        testList = clean_testList

    return testList
2602
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test and serialize it when the build succeeds.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: name of the test, passed to ``self.createSerializer``.
        dtype_or_dtypeList: a single data type (repeated for every operand)
            or a list with one entry per operand.
        error_name: ERROR_IF being tested, or None.
        shapeList: list of operand shapes.
        testArgs: a dict selects the new build interface (an argsDict that
            must contain ``"dg_type"``); anything else is unpacked as
            positional arguments for the old interface.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the old-interface build call after the
            call details have been printed.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT gets one dtype per shape instead of one per operand
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Build the random tensor operands and the test
    qgen = op.get("qgen")
    qinfo = None
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new testArgs interface with an argsDict dictionary
    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(
            self, opName, dtypeList, shapeList, argsDict, error_name
        )
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        try:
            # Without validators qinfo is positional; with validators it is
            # passed by keyword. It is omitted whenever it is None.
            call_args = [self, op, *tens, *testArgs]
            call_kwargs = {}
            if error_if_validators is None:
                if qinfo is not None:
                    call_args.append(qinfo)
            else:
                call_kwargs["validator_fcns"] = error_if_validators
                call_kwargs["error_name"] = error_name
                if qinfo is not None:
                    call_kwargs["qinfo"] = qinfo
            result = build_fcn(*call_args, **call_kwargs)
        except TypeError as e:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002728
def createDynamicOpLists(self):
    """Expand the convolution templates into one op entry per kernel size.

    Each ``*_TEMPLATE`` entry of ``self.TOSA_OP_LIST`` is shallow-copied for
    every kernel, given that kernel as its ``"filter"`` and marked
    ``"template": False``. Entries whose ``"template"`` value is still truthy
    are then removed. Nothing happens when ``"conv2d_TEMPLATE"`` is already
    gone, which can occur when the class is initialized more than once.
    """
    op_list = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_from_template(name, template_name, kernel):
        # Shallow copy of the template, specialised for one kernel
        entry = op_list[template_name].copy()
        op_list[name] = entry
        entry["filter"] = kernel
        entry["template"] = False

    templates_2d = (
        ("conv2d_{}x{}", "conv2d_TEMPLATE"),
        ("depthwise_conv2d_{}x{}", "depthwise_conv2d_TEMPLATE"),
        ("transpose_conv2d_{}x{}", "transpose_conv2d_TEMPLATE"),
    )
    for kernel in kernels_2d:
        for name_pattern, template_name in templates_2d:
            add_from_template(
                name_pattern.format(kernel[0], kernel[1]), template_name, kernel
            )

    for kernel in kernels_3d:
        add_from_template(
            "conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2]),
            "conv3d_TEMPLATE",
            kernel,
        )

    # Delete any templates after having created any dynamic ops.
    # The names are collected first because keys must not be deleted from a
    # dictionary while iterating over it.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template", False)
    ]
    for name in template_names:
        del op_list[name]
2784
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Look for missing required fields (datastructure linting).

    Every entry of ``self.TOSA_OP_LIST`` must provide:
      * "operands": an iterable of exactly two items,
      * "build_fcn": an iterable of exactly four items,
      * "types",
      * "op".
    An entry without a "rank" key gets ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: if any required field is missing or has the wrong number
            of elements; the message names the offending op.
    """
    for op, entry in self.TOSA_OP_LIST.items():

        # Required fields
        try:
            # Unpacking doubles as a check that there are exactly 2 items
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            # Unpacking doubles as a check that there are exactly 4 items
            _fcn, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002826
# Tensor operator list
#  'op': op name
#  'operands': tuple of (placeholder, const) operands
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function or None)
#  'types': array of datatypes to be tested
# Data type lists referenced by the "types" field of the operator entries.

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8, INT16 and the floating-point types (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range (min, max inclusive) applied to ops that do not set "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002878
2879 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002880 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002881 "argmax": {
2882 "op": Op.ARGMAX,
2883 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002884 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002885 "build_fcn": (
2886 build_argmax,
2887 TosaTensorGen.tgBasic,
2888 TosaTensorValuesGen.tvgDefault,
2889 TosaArgGen.agAxis,
2890 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002891 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002892 "error_if_validators": (
2893 TosaErrorValidator.evAxisSmallerZero,
2894 TosaErrorValidator.evAxisLargerRank,
2895 TosaErrorValidator.evArgmaxOutputRankMismatch,
2896 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2897 TosaErrorValidator.evWrongRank,
2898 TosaErrorValidator.evWrongInputType,
2899 TosaErrorValidator.evWrongOutputType,
2900 TosaErrorValidator.evWrongInputList,
2901 TosaErrorValidator.evWrongOutputList,
2902 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002903 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002904 "avg_pool2d": {
2905 "op": Op.AVG_POOL2D,
2906 "operands": (1, 0),
2907 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002908 "build_fcn": (
2909 build_pool2d,
2910 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002911 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002912 TosaArgGen.agPooling,
2913 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002914 "qgen": TosaQuantGen.qgUnary,
2915 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002916 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002917 "error_if_validators": (
2918 TosaErrorValidator.evKernelSmallerOne,
2919 TosaErrorValidator.evStrideSmallerOne,
2920 TosaErrorValidator.evPadSmallerZero,
2921 TosaErrorValidator.evWrongRank,
2922 TosaErrorValidator.evWrongInputType,
2923 TosaErrorValidator.evWrongOutputType,
2924 TosaErrorValidator.evWrongInputList,
2925 TosaErrorValidator.evWrongOutputList,
2926 TosaErrorValidator.evInputZeroPointNotZero,
2927 TosaErrorValidator.evOutputZeroPointNotZero,
2928 TosaErrorValidator.evPadLargerEqualKernel,
2929 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002930 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002931 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002932 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002933 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002934 "conv2d_TEMPLATE": {
2935 "op": Op.CONV2D,
2936 "operands": (1, 2),
2937 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002938 "build_fcn": (
2939 build_conv2d,
2940 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002941 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002942 TosaArgGen.agConv,
2943 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002944 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002945 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002946 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2947 "error_if_validators": (
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 TosaErrorValidator.evInputZeroPointNotZero,
2953 TosaErrorValidator.evWeightZeroPointNotZero,
2954 TosaErrorValidator.evPadSmallerZero,
2955 TosaErrorValidator.evStrideSmallerOne,
2956 TosaErrorValidator.evDilationSmallerOne,
2957 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002958 TosaErrorValidator.evConvOutputShapeMismatch,
2959 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002960 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002961 "data_gen": {
2962 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2963 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002964 "template": True,
2965 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002966 # Templated operator. Filled in by createDynamicOpLists
2967 "conv3d_TEMPLATE": {
2968 "op": Op.CONV3D,
2969 "operands": (1, 2),
2970 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002971 "build_fcn": (
2972 build_conv3d,
2973 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002974 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002975 TosaArgGen.agConv,
2976 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002977 "qgen": TosaQuantGen.qgConv,
2978 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002979 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2980 "error_if_validators": (
2981 TosaErrorValidator.evWrongInputType,
2982 TosaErrorValidator.evWrongOutputType,
2983 TosaErrorValidator.evWrongInputList,
2984 TosaErrorValidator.evWrongOutputList,
2985 TosaErrorValidator.evInputZeroPointNotZero,
2986 TosaErrorValidator.evWeightZeroPointNotZero,
2987 TosaErrorValidator.evPadSmallerZero,
2988 TosaErrorValidator.evStrideSmallerOne,
2989 TosaErrorValidator.evDilationSmallerOne,
2990 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002991 TosaErrorValidator.evConvOutputShapeMismatch,
2992 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002993 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002994 "template": True,
2995 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002996 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002997 "depthwise_conv2d_TEMPLATE": {
2998 "op": Op.DEPTHWISE_CONV2D,
2999 "operands": (1, 2),
3000 "filter": [1, 1],
3001 "rank": (4, 4),
3002 "build_fcn": (
3003 build_depthwise_conv2d,
3004 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003005 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003006 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003007 ),
3008 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003009 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003010 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3011 "error_if_validators": (
3012 TosaErrorValidator.evWrongInputType,
3013 TosaErrorValidator.evWrongOutputType,
3014 TosaErrorValidator.evWrongInputList,
3015 TosaErrorValidator.evWrongOutputList,
3016 TosaErrorValidator.evInputZeroPointNotZero,
3017 TosaErrorValidator.evWeightZeroPointNotZero,
3018 TosaErrorValidator.evPadSmallerZero,
3019 TosaErrorValidator.evStrideSmallerOne,
3020 TosaErrorValidator.evDilationSmallerOne,
3021 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003022 TosaErrorValidator.evConvOutputShapeMismatch,
3023 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003024 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003025 "template": True,
3026 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 "fully_connected": {
3028 "op": Op.FULLY_CONNECTED,
3029 "operands": (1, 2),
3030 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003031 "build_fcn": (
3032 build_fully_connected,
3033 TosaTensorGen.tgFullyConnected,
3034 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01003035 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003038 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003039 "error_if_validators": (
3040 TosaErrorValidator.evInputZeroPointNotZero,
3041 TosaErrorValidator.evWeightZeroPointNotZero,
3042 TosaErrorValidator.evWrongRank,
3043 TosaErrorValidator.evWrongInputType,
3044 TosaErrorValidator.evWrongOutputType,
3045 TosaErrorValidator.evWrongInputList,
3046 TosaErrorValidator.evWrongOutputList,
3047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003049 "matmul": {
3050 "op": Op.MATMUL,
3051 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003052 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003053 "build_fcn": (
3054 build_matmul,
3055 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003056 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003057 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003059 "qgen": TosaQuantGen.qgMatmul,
3060 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003061 "error_if_validators": (
3062 TosaErrorValidator.evInputZeroPointNotZero,
3063 TosaErrorValidator.evWrongRank,
3064 TosaErrorValidator.evWrongInputType,
3065 TosaErrorValidator.evWrongOutputType,
3066 TosaErrorValidator.evWrongInputList,
3067 TosaErrorValidator.evWrongOutputList,
3068 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003069 "data_gen": {
3070 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003071 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003073 "max_pool2d": {
3074 "op": Op.MAX_POOL2D,
3075 "operands": (1, 0),
3076 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003078 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003079 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003080 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003081 TosaArgGen.agPooling,
3082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003084 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003085 "error_if_validators": (
3086 TosaErrorValidator.evKernelSmallerOne,
3087 TosaErrorValidator.evStrideSmallerOne,
3088 TosaErrorValidator.evPadSmallerZero,
3089 TosaErrorValidator.evWrongRank,
3090 TosaErrorValidator.evWrongInputType,
3091 TosaErrorValidator.evWrongOutputType,
3092 TosaErrorValidator.evWrongInputList,
3093 TosaErrorValidator.evWrongOutputList,
3094 TosaErrorValidator.evPadLargerEqualKernel,
3095 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003096 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003097 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003098 "data_gen": {
3099 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003101 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003102 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003103 "transpose_conv2d_TEMPLATE": {
3104 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003105 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003106 "rank": (4, 4),
3107 "build_fcn": (
3108 build_transpose_conv2d,
3109 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003110 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003111 TosaArgGen.agTransposeConv2D,
3112 ),
3113 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003114 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003115 "invalid_test_validators": (
3116 TosaInvalidValidator.ivHeightWidthInvalid,
3117 TosaInvalidValidator.ivNonPositiveOutputShape,
3118 ),
3119 "error_if_validators": (
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 TosaErrorValidator.evInputZeroPointNotZero,
3125 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003126 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003127 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003128 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003129 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003130 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003131 "template": True,
3132 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003133 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003134 "clamp": {
3135 "op": Op.CLAMP,
3136 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 "build_fcn": (
3138 build_clamp,
3139 TosaTensorGen.tgBasic,
3140 TosaTensorValuesGen.tvgDefault,
3141 None,
3142 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003143 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003144 "error_if_validators": (
3145 TosaErrorValidator.evMaxSmallerMin,
3146 TosaErrorValidator.evWrongInputType,
3147 TosaErrorValidator.evWrongOutputType,
3148 TosaErrorValidator.evWrongInputList,
3149 TosaErrorValidator.evWrongOutputList,
3150 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003151 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003152 "sigmoid": {
3153 "op": Op.SIGMOID,
3154 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003155 "build_fcn": (
3156 build_sigmoid,
3157 TosaTensorGen.tgBasic,
3158 TosaTensorValuesGen.tvgDefault,
3159 None,
3160 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003161 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003162 "error_if_validators": (
3163 TosaErrorValidator.evWrongInputType,
3164 TosaErrorValidator.evWrongOutputType,
3165 TosaErrorValidator.evWrongInputList,
3166 TosaErrorValidator.evWrongOutputList,
3167 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003168 },
3169 "tanh": {
3170 "op": Op.TANH,
3171 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003172 "build_fcn": (
3173 build_tanh,
3174 TosaTensorGen.tgBasic,
3175 TosaTensorValuesGen.tvgDefault,
3176 None,
3177 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003178 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003179 "error_if_validators": (
3180 TosaErrorValidator.evWrongInputType,
3181 TosaErrorValidator.evWrongOutputType,
3182 TosaErrorValidator.evWrongInputList,
3183 TosaErrorValidator.evWrongOutputList,
3184 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003185 },
Won Jeon78155c62023-06-10 00:20:04 +00003186 "erf": {
3187 "op": Op.ERF,
3188 "operands": (1, 0),
3189 "build_fcn": (
3190 build_erf,
3191 TosaTensorGen.tgBasic,
3192 TosaTensorValuesGen.tvgDefault,
3193 None,
3194 ),
3195 "types": TYPE_FP,
3196 "error_if_validators": (
3197 TosaErrorValidator.evWrongInputType,
3198 TosaErrorValidator.evWrongOutputType,
3199 TosaErrorValidator.evWrongInputList,
3200 TosaErrorValidator.evWrongOutputList,
3201 ),
3202 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 # Elementwise Binary Operators
3204 "add": {
3205 "op": Op.ADD,
3206 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003207 "build_fcn": (
3208 build_binary_broadcast,
3209 TosaTensorGen.tgBroadcastFuzz,
3210 TosaTensorValuesGen.tvgAddSub,
3211 None,
3212 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003213 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003214 "error_if_validators": (
3215 TosaErrorValidator.evRankMismatch,
3216 TosaErrorValidator.evWrongInputType,
3217 TosaErrorValidator.evWrongOutputType,
3218 TosaErrorValidator.evWrongInputList,
3219 TosaErrorValidator.evWrongOutputList,
3220 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003221 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "arithmetic_right_shift": {
3225 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3226 "operands": (2, 0),
3227 "build_fcn": (
3228 build_arithmetic_right_shift,
3229 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003230 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003231 TosaArgGen.agArithmeticRightShift,
3232 ),
3233 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003234 "error_if_validators": (
3235 TosaErrorValidator.evRankMismatch,
3236 TosaErrorValidator.evWrongInputType,
3237 TosaErrorValidator.evWrongOutputType,
3238 TosaErrorValidator.evWrongInputList,
3239 TosaErrorValidator.evWrongOutputList,
3240 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003241 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003244 "bitwise_and": {
3245 "op": Op.BITWISE_AND,
3246 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 "build_fcn": (
3248 build_binary_broadcast,
3249 TosaTensorGen.tgBroadcastFuzz,
3250 TosaTensorValuesGen.tvgDefault,
3251 None,
3252 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003254 "error_if_validators": (
3255 TosaErrorValidator.evRankMismatch,
3256 TosaErrorValidator.evWrongInputType,
3257 TosaErrorValidator.evWrongOutputType,
3258 TosaErrorValidator.evWrongInputList,
3259 TosaErrorValidator.evWrongOutputList,
3260 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003261 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003264 "bitwise_or": {
3265 "op": Op.BITWISE_OR,
3266 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 "build_fcn": (
3268 build_binary_broadcast,
3269 TosaTensorGen.tgBroadcastFuzz,
3270 TosaTensorValuesGen.tvgDefault,
3271 None,
3272 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003274 "error_if_validators": (
3275 TosaErrorValidator.evRankMismatch,
3276 TosaErrorValidator.evWrongInputType,
3277 TosaErrorValidator.evWrongOutputType,
3278 TosaErrorValidator.evWrongInputList,
3279 TosaErrorValidator.evWrongOutputList,
3280 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003281 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003282 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 "bitwise_xor": {
3285 "op": Op.BITWISE_XOR,
3286 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003287 "build_fcn": (
3288 build_binary_broadcast,
3289 TosaTensorGen.tgBroadcastFuzz,
3290 TosaTensorValuesGen.tvgDefault,
3291 None,
3292 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003294 "error_if_validators": (
3295 TosaErrorValidator.evRankMismatch,
3296 TosaErrorValidator.evWrongInputType,
3297 TosaErrorValidator.evWrongOutputType,
3298 TosaErrorValidator.evWrongInputList,
3299 TosaErrorValidator.evWrongOutputList,
3300 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003301 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003302 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003304 "intdiv": {
3305 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003306 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 "build_fcn": (
3308 build_binary_broadcast,
3309 TosaTensorGen.tgBroadcastFuzz,
3310 TosaTensorValuesGen.tvgIntDiv,
3311 None,
3312 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003313 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003314 "error_if_validators": (
3315 TosaErrorValidator.evRankMismatch,
3316 TosaErrorValidator.evWrongInputType,
3317 TosaErrorValidator.evWrongOutputType,
3318 TosaErrorValidator.evWrongInputList,
3319 TosaErrorValidator.evWrongOutputList,
3320 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003321 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 "logical_and": {
3325 "op": Op.LOGICAL_AND,
3326 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003327 "build_fcn": (
3328 build_binary_broadcast,
3329 TosaTensorGen.tgBroadcastFuzz,
3330 TosaTensorValuesGen.tvgDefault,
3331 None,
3332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003334 "error_if_validators": (
3335 TosaErrorValidator.evRankMismatch,
3336 TosaErrorValidator.evWrongInputType,
3337 TosaErrorValidator.evWrongOutputType,
3338 TosaErrorValidator.evWrongInputList,
3339 TosaErrorValidator.evWrongOutputList,
3340 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003341 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003344 "logical_left_shift": {
3345 "op": Op.LOGICAL_LEFT_SHIFT,
3346 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 "build_fcn": (
3348 build_binary_broadcast,
3349 TosaTensorGen.tgBroadcastFuzz,
3350 TosaTensorValuesGen.tvgLogicalShift,
3351 None,
3352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003354 "error_if_validators": (
3355 TosaErrorValidator.evRankMismatch,
3356 TosaErrorValidator.evWrongInputType,
3357 TosaErrorValidator.evWrongOutputType,
3358 TosaErrorValidator.evWrongInputList,
3359 TosaErrorValidator.evWrongOutputList,
3360 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003361 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "logical_right_shift": {
3365 "op": Op.LOGICAL_RIGHT_SHIFT,
3366 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 "build_fcn": (
3368 build_binary_broadcast,
3369 TosaTensorGen.tgBroadcastFuzz,
3370 TosaTensorValuesGen.tvgLogicalShift,
3371 None,
3372 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003374 "error_if_validators": (
3375 TosaErrorValidator.evRankMismatch,
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003381 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003382 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 "logical_or": {
3385 "op": Op.LOGICAL_OR,
3386 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 "build_fcn": (
3388 build_binary_broadcast,
3389 TosaTensorGen.tgBroadcastFuzz,
3390 TosaTensorValuesGen.tvgDefault,
3391 None,
3392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003394 "error_if_validators": (
3395 TosaErrorValidator.evRankMismatch,
3396 TosaErrorValidator.evWrongInputType,
3397 TosaErrorValidator.evWrongOutputType,
3398 TosaErrorValidator.evWrongInputList,
3399 TosaErrorValidator.evWrongOutputList,
3400 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003401 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 "logical_xor": {
3405 "op": Op.LOGICAL_XOR,
3406 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 "build_fcn": (
3408 build_binary_broadcast,
3409 TosaTensorGen.tgBroadcastFuzz,
3410 TosaTensorValuesGen.tvgDefault,
3411 None,
3412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003414 "error_if_validators": (
3415 TosaErrorValidator.evRankMismatch,
3416 TosaErrorValidator.evWrongInputType,
3417 TosaErrorValidator.evWrongOutputType,
3418 TosaErrorValidator.evWrongInputList,
3419 TosaErrorValidator.evWrongOutputList,
3420 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003421 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 "maximum": {
3425 "op": Op.MAXIMUM,
3426 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003427 "build_fcn": (
3428 build_binary_broadcast,
3429 TosaTensorGen.tgBroadcastFuzz,
3430 TosaTensorValuesGen.tvgDefault,
3431 None,
3432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003434 "error_if_validators": (
3435 TosaErrorValidator.evRankMismatch,
3436 TosaErrorValidator.evWrongInputType,
3437 TosaErrorValidator.evWrongOutputType,
3438 TosaErrorValidator.evWrongInputList,
3439 TosaErrorValidator.evWrongOutputList,
3440 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003441 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "minimum": {
3445 "op": Op.MINIMUM,
3446 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003447 "build_fcn": (
3448 build_binary_broadcast,
3449 TosaTensorGen.tgBroadcastFuzz,
3450 TosaTensorValuesGen.tvgDefault,
3451 None,
3452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003454 "error_if_validators": (
3455 TosaErrorValidator.evRankMismatch,
3456 TosaErrorValidator.evWrongInputType,
3457 TosaErrorValidator.evWrongOutputType,
3458 TosaErrorValidator.evWrongInputList,
3459 TosaErrorValidator.evWrongOutputList,
3460 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003461 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "mul": {
3465 "op": Op.MUL,
3466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_mul,
3469 TosaTensorGen.tgBroadcastFuzz,
3470 TosaTensorValuesGen.tvgMul,
3471 TosaArgGen.agMul,
3472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList,
3478 TosaErrorValidator.evWrongOutputList,
3479 TosaErrorValidator.evRankMismatch,
3480 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003481 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003483 "data_gen": {
3484 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3485 },
3486 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 "pow": {
3489 "op": Op.POW,
3490 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003491 "build_fcn": (
3492 build_binary_broadcast,
3493 TosaTensorGen.tgBroadcastFuzz,
3494 TosaTensorValuesGen.tvgDefault,
3495 None,
3496 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003497 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003498 "error_if_validators": (
3499 TosaErrorValidator.evRankMismatch,
3500 TosaErrorValidator.evWrongInputType,
3501 TosaErrorValidator.evWrongOutputType,
3502 TosaErrorValidator.evWrongInputList,
3503 TosaErrorValidator.evWrongOutputList,
3504 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003505 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003506 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003507 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003508 "sub": {
3509 "op": Op.SUB,
3510 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003511 "build_fcn": (
3512 build_binary_broadcast,
3513 TosaTensorGen.tgBroadcastFuzz,
3514 TosaTensorValuesGen.tvgAddSub,
3515 None,
3516 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003517 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003518 "error_if_validators": (
3519 TosaErrorValidator.evRankMismatch,
3520 TosaErrorValidator.evWrongInputType,
3521 TosaErrorValidator.evWrongOutputType,
3522 TosaErrorValidator.evWrongInputList,
3523 TosaErrorValidator.evWrongOutputList,
3524 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003525 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 "table": {
3529 "op": Op.TABLE,
3530 # Use the automatic generation functions to create the input array
3531 # but create the table tensor in the build function, as it may be
3532 # a different type from the input
3533 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 "build_fcn": (
3535 build_table,
3536 TosaTensorGen.tgBasic,
3537 TosaTensorValuesGen.tvgDefault,
3538 TosaArgGen.agTable,
3539 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003540 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003541 "error_if_validators": (
3542 TosaErrorValidator.evWrongInputType,
3543 TosaErrorValidator.evWrongOutputType,
3544 TosaErrorValidator.evWrongInputList,
3545 TosaErrorValidator.evWrongOutputList,
3546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 # Elementwise Unary operators
3549 "abs": {
3550 "op": Op.ABS,
3551 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_unary,
3554 TosaTensorGen.tgBasic,
3555 TosaTensorValuesGen.tvgDefault,
3556 None,
3557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003559 "error_if_validators": (
3560 TosaErrorValidator.evWrongInputType,
3561 TosaErrorValidator.evWrongOutputType,
3562 TosaErrorValidator.evWrongInputList,
3563 TosaErrorValidator.evWrongOutputList,
3564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003566 "bitwise_not": {
3567 "op": Op.BITWISE_NOT,
3568 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 "build_fcn": (
3570 build_unary,
3571 TosaTensorGen.tgBasic,
3572 TosaTensorValuesGen.tvgDefault,
3573 None,
3574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 "error_if_validators": (
3577 TosaErrorValidator.evWrongInputType,
3578 TosaErrorValidator.evWrongOutputType,
3579 TosaErrorValidator.evWrongInputList,
3580 TosaErrorValidator.evWrongOutputList,
3581 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "ceil": {
3584 "op": Op.CEIL,
3585 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_unary,
3588 TosaTensorGen.tgBasic,
3589 TosaTensorValuesGen.tvgDefault,
3590 None,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003600 "clz": {
3601 "op": Op.CLZ,
3602 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 "build_fcn": (
3604 build_unary,
3605 TosaTensorGen.tgBasic,
3606 TosaTensorValuesGen.tvgDefault,
3607 None,
3608 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003609 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003610 "error_if_validators": (
3611 TosaErrorValidator.evWrongInputType,
3612 TosaErrorValidator.evWrongOutputType,
3613 TosaErrorValidator.evWrongInputList,
3614 TosaErrorValidator.evWrongOutputList,
3615 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003616 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 "exp": {
3618 "op": Op.EXP,
3619 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003620 "build_fcn": (
3621 build_unary,
3622 TosaTensorGen.tgBasic,
3623 TosaTensorValuesGen.tvgDefault,
3624 None,
3625 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003627 "error_if_validators": (
3628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongInputList,
3631 TosaErrorValidator.evWrongOutputList,
3632 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003633 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 "floor": {
3635 "op": Op.FLOOR,
3636 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 "build_fcn": (
3638 build_unary,
3639 TosaTensorGen.tgBasic,
3640 TosaTensorValuesGen.tvgDefault,
3641 None,
3642 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003644 "error_if_validators": (
3645 TosaErrorValidator.evWrongInputType,
3646 TosaErrorValidator.evWrongOutputType,
3647 TosaErrorValidator.evWrongInputList,
3648 TosaErrorValidator.evWrongOutputList,
3649 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003650 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 "log": {
3652 "op": Op.LOG,
3653 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 "build_fcn": (
3655 build_unary,
3656 TosaTensorGen.tgBasic,
3657 TosaTensorValuesGen.tvgDefault,
3658 None,
3659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003661 "error_if_validators": (
3662 TosaErrorValidator.evWrongInputType,
3663 TosaErrorValidator.evWrongOutputType,
3664 TosaErrorValidator.evWrongInputList,
3665 TosaErrorValidator.evWrongOutputList,
3666 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 "logical_not": {
3669 "op": Op.LOGICAL_NOT,
3670 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003671 "build_fcn": (
3672 build_unary,
3673 TosaTensorGen.tgBasic,
3674 TosaTensorValuesGen.tvgDefault,
3675 None,
3676 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003678 "error_if_validators": (
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003684 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "negate": {
3686 "op": Op.NEGATE,
3687 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003688 "build_fcn": (
3689 build_unary,
3690 TosaTensorGen.tgBasic,
3691 TosaTensorValuesGen.tvgNegate,
3692 None,
3693 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 "qgen": TosaQuantGen.qgUnary,
3695 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 "error_if_validators": (
3697 TosaErrorValidator.evInputZeroPointNotZero,
3698 TosaErrorValidator.evOutputZeroPointNotZero,
3699 TosaErrorValidator.evWrongInputType,
3700 TosaErrorValidator.evWrongOutputType,
3701 TosaErrorValidator.evWrongInputList,
3702 TosaErrorValidator.evWrongOutputList,
3703 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003704 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 "reciprocal": {
3706 "op": Op.RECIPROCAL,
3707 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003708 "build_fcn": (
3709 build_unary,
3710 TosaTensorGen.tgBasic,
3711 TosaTensorValuesGen.tvgDefault,
3712 None,
3713 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003714 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003715 "error_if_validators": (
3716 TosaErrorValidator.evWrongInputType,
3717 TosaErrorValidator.evWrongOutputType,
3718 TosaErrorValidator.evWrongInputList,
3719 TosaErrorValidator.evWrongOutputList,
3720 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 "rsqrt": {
3723 "op": Op.RSQRT,
3724 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 "build_fcn": (
3726 build_unary,
3727 TosaTensorGen.tgBasic,
3728 TosaTensorValuesGen.tvgDefault,
3729 None,
3730 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 "error_if_validators": (
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003739 # Elementwise Ternary operators
3740 "select": {
3741 "op": Op.SELECT,
3742 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003743 "build_fcn": (
3744 build_select,
3745 TosaTensorGen.tgBroadcastFuzz,
3746 TosaTensorValuesGen.tvgSelect,
3747 None,
3748 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 "error_if_validators": (
3751 TosaErrorValidator.evRankMismatch,
3752 TosaErrorValidator.evWrongInputType,
3753 TosaErrorValidator.evWrongOutputType,
3754 TosaErrorValidator.evWrongInputList,
3755 TosaErrorValidator.evWrongOutputList,
3756 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003757 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003758 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 # Comparison operators
3761 "equal": {
3762 "op": Op.EQUAL,
3763 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 "build_fcn": (
3765 build_comparison,
3766 TosaTensorGen.tgBroadcastFuzz,
3767 TosaTensorValuesGen.tvgEqual,
3768 None,
3769 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 "error_if_validators": (
3772 TosaErrorValidator.evRankMismatch,
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003778 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "greater_equal": {
3782 "op": Op.GREATER_EQUAL,
3783 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_comparison,
3786 TosaTensorGen.tgBroadcastFuzz,
3787 TosaTensorValuesGen.tvgDefault,
3788 None,
3789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "error_if_validators": (
3792 TosaErrorValidator.evRankMismatch,
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003798 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003800 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 "greater": {
3802 "op": Op.GREATER,
3803 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 "build_fcn": (
3805 build_comparison,
3806 TosaTensorGen.tgBroadcastFuzz,
3807 TosaTensorValuesGen.tvgDefault,
3808 None,
3809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003810 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 "error_if_validators": (
3812 TosaErrorValidator.evRankMismatch,
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003818 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 # Reduction operators
3822 "reduce_all": {
3823 "op": Op.REDUCE_ALL,
3824 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003825 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003826 "build_fcn": (
3827 build_reduce,
3828 TosaTensorGen.tgBasic,
3829 TosaTensorValuesGen.tvgDefault,
3830 TosaArgGen.agAxis,
3831 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003833 "error_if_validators": (
3834 TosaErrorValidator.evAxisLargerRank,
3835 TosaErrorValidator.evAxisSmallerZero,
3836 TosaErrorValidator.evShapeOfAxisNotOne,
3837 TosaErrorValidator.evWrongInputType,
3838 TosaErrorValidator.evWrongOutputType,
3839 TosaErrorValidator.evWrongRank,
3840 TosaErrorValidator.evWrongInputList,
3841 TosaErrorValidator.evWrongOutputList,
3842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "reduce_any": {
3845 "op": Op.REDUCE_ANY,
3846 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003847 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003848 "build_fcn": (
3849 build_reduce,
3850 TosaTensorGen.tgBasic,
3851 TosaTensorValuesGen.tvgDefault,
3852 TosaArgGen.agAxis,
3853 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003854 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003855 "error_if_validators": (
3856 TosaErrorValidator.evAxisLargerRank,
3857 TosaErrorValidator.evAxisSmallerZero,
3858 TosaErrorValidator.evShapeOfAxisNotOne,
3859 TosaErrorValidator.evWrongInputType,
3860 TosaErrorValidator.evWrongOutputType,
3861 TosaErrorValidator.evWrongRank,
3862 TosaErrorValidator.evWrongInputList,
3863 TosaErrorValidator.evWrongOutputList,
3864 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003865 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 "reduce_max": {
3867 "op": Op.REDUCE_MAX,
3868 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003869 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003870 "build_fcn": (
3871 build_reduce,
3872 TosaTensorGen.tgBasic,
3873 TosaTensorValuesGen.tvgDefault,
3874 TosaArgGen.agAxis,
3875 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003877 "error_if_validators": (
3878 TosaErrorValidator.evAxisLargerRank,
3879 TosaErrorValidator.evAxisSmallerZero,
3880 TosaErrorValidator.evShapeOfAxisNotOne,
3881 TosaErrorValidator.evWrongInputType,
3882 TosaErrorValidator.evWrongOutputType,
3883 TosaErrorValidator.evWrongRank,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003887 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003889 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003891 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 "build_fcn": (
3893 build_reduce,
3894 TosaTensorGen.tgBasic,
3895 TosaTensorValuesGen.tvgDefault,
3896 TosaArgGen.agAxis,
3897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 "error_if_validators": (
3900 TosaErrorValidator.evAxisLargerRank,
3901 TosaErrorValidator.evAxisSmallerZero,
3902 TosaErrorValidator.evShapeOfAxisNotOne,
3903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongRank,
3906 TosaErrorValidator.evWrongInputList,
3907 TosaErrorValidator.evWrongOutputList,
3908 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 "reduce_product": {
3911 "op": Op.REDUCE_PRODUCT,
3912 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003913 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 "build_fcn": (
3915 build_reduce,
3916 TosaTensorGen.tgBasic,
3917 TosaTensorValuesGen.tvgDefault,
3918 TosaArgGen.agAxis,
3919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 "error_if_validators": (
3922 TosaErrorValidator.evAxisLargerRank,
3923 TosaErrorValidator.evAxisSmallerZero,
3924 TosaErrorValidator.evShapeOfAxisNotOne,
3925 TosaErrorValidator.evWrongInputType,
3926 TosaErrorValidator.evWrongOutputType,
3927 TosaErrorValidator.evWrongRank,
3928 TosaErrorValidator.evWrongInputList,
3929 TosaErrorValidator.evWrongOutputList,
3930 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003932 "reduce_sum": {
3933 "op": Op.REDUCE_SUM,
3934 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003935 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003936 "build_fcn": (
3937 build_reduce,
3938 TosaTensorGen.tgBasic,
3939 TosaTensorValuesGen.tvgReduceSum,
3940 TosaArgGen.agAxis,
3941 ),
James Ward24dbc422022-10-19 12:20:31 +01003942 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003943 "error_if_validators": (
3944 TosaErrorValidator.evAxisLargerRank,
3945 TosaErrorValidator.evAxisSmallerZero,
3946 TosaErrorValidator.evShapeOfAxisNotOne,
3947 TosaErrorValidator.evWrongInputType,
3948 TosaErrorValidator.evWrongOutputType,
3949 TosaErrorValidator.evWrongRank,
3950 TosaErrorValidator.evWrongInputList,
3951 TosaErrorValidator.evWrongOutputList,
3952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003954 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003955 "concat": {
3956 "op": Op.CONCAT,
3957 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003958 "build_fcn": (
3959 build_concat,
3960 TosaTensorGen.tgConcat,
3961 TosaTensorValuesGen.tvgConcat,
3962 TosaArgGen.agAxis,
3963 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003964 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003965 "error_if_validators": (
3966 TosaErrorValidator.evAxisLargerRank,
3967 TosaErrorValidator.evAxisSmallerZero,
3968 TosaErrorValidator.evConcatInputRankMismatch,
3969 TosaErrorValidator.evConcatShapeSumMismatch,
3970 TosaErrorValidator.evConcatInputDimMismatch,
3971 TosaErrorValidator.evWrongInputType,
3972 TosaErrorValidator.evWrongOutputType,
3973 TosaErrorValidator.evWrongOutputList,
3974 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003975 },
3976 "pad": {
3977 "op": Op.PAD,
3978 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003979 "build_fcn": (
3980 build_pad,
3981 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003982 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003983 TosaArgGen.agPad,
3984 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003985 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 "error_if_validators": (
3987 TosaErrorValidator.evWrongInputType,
3988 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003989 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003990 TosaErrorValidator.evWrongOutputType,
3991 TosaErrorValidator.evWrongInputList,
3992 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003993 TosaErrorValidator.evRankMismatch,
3994 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003996 "data_gen": {
3997 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3998 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003999 },
Won Jeona21b2e82023-08-10 10:33:01 +00004000 "dim": {
4001 "op": Op.DIM,
4002 "operands": (1, 0),
4003 "build_fcn": (
4004 build_dim,
4005 TosaTensorGen.tgBasic,
4006 TosaTensorValuesGen.tvgDefault,
4007 TosaArgGen.agAxis,
4008 ),
4009 "types": TYPE_FIB,
4010 "error_if_validators": (
4011 TosaErrorValidator.evAxisLargerRank,
4012 TosaErrorValidator.evAxisSmallerZero,
4013 TosaErrorValidator.evWrongInputType,
4014 TosaErrorValidator.evWrongInputList,
4015 TosaErrorValidator.evWrongOutputList,
4016 TosaErrorValidator.evWrongRank,
4017 ),
4018 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004019 "reshape": {
4020 "op": Op.RESHAPE,
4021 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004022 "build_fcn": (
4023 build_reshape,
4024 TosaTensorGen.tgBasic,
4025 TosaTensorValuesGen.tvgDefault,
4026 TosaArgGen.agReshape,
4027 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004028 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004029 "error_if_validators": (
4030 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4031 TosaErrorValidator.evWrongInputType,
4032 TosaErrorValidator.evWrongOutputType,
4033 TosaErrorValidator.evWrongInputList,
4034 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004035 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4036 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004037 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004038 },
4039 "reverse": {
4040 "op": Op.REVERSE,
4041 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004042 "build_fcn": (
4043 build_reverse,
4044 TosaTensorGen.tgBasic,
4045 TosaTensorValuesGen.tvgDefault,
4046 TosaArgGen.agAxis,
4047 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004048 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004049 "error_if_validators": (
4050 TosaErrorValidator.evAxisSmallerZero,
4051 TosaErrorValidator.evAxisLargerRank,
4052 TosaErrorValidator.evWrongInputType,
4053 TosaErrorValidator.evWrongOutputType,
4054 TosaErrorValidator.evWrongInputList,
4055 TosaErrorValidator.evWrongOutputList,
4056 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004057 },
4058 "slice": {
4059 "op": Op.SLICE,
4060 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004061 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_slice,
4064 TosaTensorGen.tgBasic,
4065 TosaTensorValuesGen.tvgDefault,
4066 TosaArgGen.agSlice,
4067 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004068 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evStartSmallerZero,
4071 TosaErrorValidator.evSizeSmallerEqualZero,
4072 TosaErrorValidator.evStartSizeOutsideBounds,
4073 TosaErrorValidator.evSizeOutputShapeMismatch,
4074 TosaErrorValidator.evInputSizeStartLengthMismatch,
4075 TosaErrorValidator.evWrongRank,
4076 TosaErrorValidator.evWrongInputType,
4077 TosaErrorValidator.evWrongOutputType,
4078 TosaErrorValidator.evWrongInputList,
4079 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004080 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004081 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004082 },
4083 "tile": {
4084 "op": Op.TILE,
4085 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004086 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 "build_fcn": (
4088 build_tile,
4089 TosaTensorGen.tgBasic,
4090 TosaTensorValuesGen.tvgDefault,
4091 TosaArgGen.agTile,
4092 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004093 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004094 "error_if_validators": (
4095 TosaErrorValidator.evWrongInputType,
4096 TosaErrorValidator.evWrongOutputType,
4097 TosaErrorValidator.evWrongInputList,
4098 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004099 TosaErrorValidator.evRankMismatch,
4100 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004101 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004102 },
4103 "transpose": {
4104 "op": Op.TRANSPOSE,
4105 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004106 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004107 "build_fcn": (
4108 build_transpose,
4109 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004110 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004111 TosaArgGen.agTranspose,
4112 ),
4113 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004114 "error_if_validators": (
4115 TosaErrorValidator.evIndexOutsideBounds,
4116 TosaErrorValidator.evIndexUsedTwice,
4117 TosaErrorValidator.evWrongInputType,
4118 TosaErrorValidator.evWrongOutputType,
4119 TosaErrorValidator.evWrongInputList,
4120 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004121 TosaErrorValidator.evWrongRank,
4122 TosaErrorValidator.evRankMismatch,
4123 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004124 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 # Data nodes
4127 "const": {
4128 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004129 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004130 "build_fcn": (
4131 build_const,
4132 TosaTensorGen.tgBasic,
4133 TosaTensorValuesGen.tvgDefault,
4134 None,
4135 ),
Luke Hutton65872422023-02-20 10:33:04 +00004136 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004138 "identity": {
4139 "op": Op.IDENTITY,
4140 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004141 "build_fcn": (
4142 build_unary,
4143 TosaTensorGen.tgBasic,
4144 TosaTensorValuesGen.tvgDefault,
4145 None,
4146 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004147 "types": TYPE_FIB,
4148 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004149 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004150 "gather": {
4151 "op": Op.GATHER,
4152 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4153 "operands": (1, 0),
4154 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004155 "build_fcn": (
4156 build_gather,
4157 TosaTensorGen.tgBasic,
4158 TosaTensorValuesGen.tvgDefault,
4159 None,
4160 ),
James Ward24dbc422022-10-19 12:20:31 +01004161 "types": (
4162 DType.INT8,
4163 DType.INT16,
4164 DType.INT32,
4165 DType.FP16,
4166 DType.BF16,
4167 DType.FP32,
4168 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004169 "error_if_validators": (
4170 TosaErrorValidator.evWrongInputType,
4171 TosaErrorValidator.evWrongOutputType,
4172 TosaErrorValidator.evWrongInputList,
4173 TosaErrorValidator.evWrongOutputList,
4174 TosaErrorValidator.evWrongRank,
4175 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004176 },
4177 "scatter": {
4178 "op": Op.SCATTER,
4179 # Specify the 'values_in' and 'input' tensors here (the two operands below).
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004180 # 'indices' is generated in op building stage (NOTE(review): confirm against build_scatter)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004181 "operands": (2, 0),
4182 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004183 "build_fcn": (
4184 build_scatter,
4185 TosaTensorGen.tgScatter,
4186 TosaTensorValuesGen.tvgDefault,
4187 None,
4188 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004189 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004190 "error_if_validators": (
4191 TosaErrorValidator.evWrongInputType,
4192 TosaErrorValidator.evWrongOutputType,
4193 TosaErrorValidator.evWrongInputList,
4194 TosaErrorValidator.evWrongOutputList,
4195 TosaErrorValidator.evWrongRank,
4196 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004197 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004198 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004199 "resize": {
4200 "op": Op.RESIZE,
4201 "operands": (1, 0),
4202 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 "build_fcn": (
4204 build_resize,
4205 TosaTensorGen.tgNHWC,
4206 TosaTensorValuesGen.tvgDefault,
4207 TosaArgGen.agResize,
4208 ),
James Ward24dbc422022-10-19 12:20:31 +01004209 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004210 "invalid_test_validators": (
4211 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004212 ),
4213 "error_if_validators": (
4214 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004215 TosaErrorValidator.evScaleSmallerEqualZero,
4216 TosaErrorValidator.evScaleNLargerMax,
4217 TosaErrorValidator.evScaleDLargerMax,
4218 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004219 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004220 TosaErrorValidator.evBorderSmallerMin,
4221 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004222 TosaErrorValidator.evWrongInputType,
4223 TosaErrorValidator.evWrongOutputType,
4224 TosaErrorValidator.evWrongRank,
4225 TosaErrorValidator.evWrongInputList,
4226 TosaErrorValidator.evWrongOutputList,
4227 TosaErrorValidator.evBatchMismatch,
4228 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004229 TosaErrorValidator.evResizeOutputShapeMismatch,
4230 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004231 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004232 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004233 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004234 "cast": {
4235 "op": Op.CAST,
4236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_cast,
4239 TosaTensorGen.tgBasic,
4240 TosaTensorValuesGen.tvgDefault,
4241 TosaArgGen.agCast,
4242 ),
James Ward8b390432022-08-12 20:48:56 +01004243 "types": (
4244 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004245 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004246 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004247 DType.INT8,
4248 DType.INT16,
4249 DType.INT32,
4250 DType.BOOL,
4251 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004252 "error_if_validators": (
4253 TosaErrorValidator.evWrongInputType,
4254 TosaErrorValidator.evWrongOutputType,
4255 TosaErrorValidator.evWrongInputList,
4256 TosaErrorValidator.evWrongOutputList,
4257 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004258 },
4259 "rescale": {
4260 "op": Op.RESCALE,
4261 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004262 "build_fcn": (
4263 build_rescale,
4264 TosaTensorGen.tgBasic,
4265 TosaTensorValuesGen.tvgDefault,
4266 TosaArgGen.agRescale,
4267 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004268 "types": [
4269 DType.UINT8,
4270 DType.INT8,
4271 DType.INT16,
4272 DType.INT32,
4273 DType.INT48,
4274 DType.UINT16,
4275 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004276 "error_if_validators": (
4277 TosaErrorValidator.evInputZeroPointNotZero,
4278 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004279 TosaErrorValidator.evU16InputZeroPointNotValid,
4280 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004281 TosaErrorValidator.evScaleTrue,
4282 TosaErrorValidator.evScaleNotTrue,
4283 TosaErrorValidator.evWrongInputType,
4284 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004285 TosaErrorValidator.evWrongInputList,
4286 TosaErrorValidator.evWrongOutputList,
4287 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004288 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004289 # Custom
4290 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004291 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004292 # Two variants of cond_if, one that generates one of two constant tensors (no
4293 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4294 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004295 "cond_if_const": {
4296 "op": Op.COND_IF,
4297 "operands": (0, 2),
4298 "build_fcn": (
4299 build_cond_if_const,
4300 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004301 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004302 TosaArgGen.agCondIf,
4303 ),
4304 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004305 "error_if_validators": (
4306 TosaErrorValidator.evOutputListThenGraphMismatch,
4307 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004308 TosaErrorValidator.evCondIfCondNotMatchingBool,
4309 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004310 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004311 },
4312 "cond_if_binary": {
4313 "op": Op.COND_IF,
4314 "operands": (2, 0),
4315 "build_fcn": (
4316 build_cond_if_binary,
4317 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004318 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004319 TosaArgGen.agCondIf,
4320 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004321 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004322 "error_if_validators": (
4323 TosaErrorValidator.evInputListThenGraphMismatch,
4324 TosaErrorValidator.evInputListElseGraphMismatch,
4325 TosaErrorValidator.evOutputListThenGraphMismatch,
4326 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004327 TosaErrorValidator.evCondIfCondNotMatchingBool,
4328 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004329 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004330 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004331 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004332 "while_loop": {
4333 "op": Op.WHILE_LOOP,
4334 "operands": (0, 1),
4335 "build_fcn": (
4336 build_while_loop,
4337 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004338 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004339 TosaArgGen.agWhileLoop,
4340 ),
4341 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004342 "error_if_validators": (
4343 TosaErrorValidator.evInputListOutputListMismatch,
4344 TosaErrorValidator.evInputListCondGraphMismatch,
4345 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4346 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4347 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004348 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004349 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004350 },
Luke Hutton57287132023-02-06 14:54:18 +00004351 "fft2d": {
4352 "op": Op.FFT2D,
4353 "operands": (2, 0),
4354 "rank": (3, 3),
4355 "build_fcn": (
4356 build_fft2d,
4357 TosaTensorGen.tgFFT2d,
4358 TosaTensorValuesGen.tvgDefault,
4359 TosaArgGen.agFFT2d,
4360 ),
4361 "types": [DType.FP32],
4362 "error_if_validators": (
4363 TosaErrorValidator.evWrongInputType,
4364 TosaErrorValidator.evWrongOutputType,
4365 TosaErrorValidator.evWrongInputList,
4366 TosaErrorValidator.evWrongOutputList,
4367 TosaErrorValidator.evWrongRank,
4368 TosaErrorValidator.evBatchMismatch,
4369 TosaErrorValidator.evKernelNotPowerOfTwo,
4370 TosaErrorValidator.evFFTInputShapeMismatch,
4371 TosaErrorValidator.evFFTOutputShapeMismatch,
4372 ),
4373 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004374 "rfft2d": {
4375 "op": Op.RFFT2D,
4376 "operands": (1, 0),
4377 "rank": (3, 3),
4378 "build_fcn": (
4379 build_rfft2d,
4380 TosaTensorGen.tgRFFT2d,
4381 TosaTensorValuesGen.tvgDefault,
4382 TosaArgGen.agNone,
4383 ),
4384 "types": [DType.FP32],
4385 "error_if_validators": (
4386 TosaErrorValidator.evWrongInputType,
4387 TosaErrorValidator.evWrongOutputType,
4388 TosaErrorValidator.evWrongInputList,
4389 TosaErrorValidator.evWrongOutputList,
4390 TosaErrorValidator.evWrongRank,
4391 TosaErrorValidator.evBatchMismatch,
4392 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004393 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004394 ),
4395 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004396 }
4397
Kevin Cheng550ccc52021-03-03 11:21:43 -08004398
Eric Kunzee5e26762020-10-13 16:11:07 -07004399class OutputShaper:
4400 # Methods in this class compute the expected output shape and datatype
4401 # for common classes of operations
4402 def __init__(self):
4403 pass
4404
4405 # These methods return arguments that can be used for
4406 # creating a new output tensor
4407 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise binary op that broadcasts.

    Args:
        ser: serializer object; only ``ser.addOutput(shape, dtype)`` is called.
        rng: random generator; ``integers`` and ``choice`` are called on it.
        a: first input tensor (its ``shape`` and ``dtype`` are read).
        b: second input tensor; must have the same dtype as ``a`` and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, the same rank.
        error_name: optional ERROR_IF case used to deliberately corrupt the
            output shape (``DimensionMismatch``) or dtype (``WrongOutputType``).

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcasting is only applied to valid tests; ERROR_IF tests keep a's
    # dimensions, which also avoids indexing b when the ranks differ.
    if error_name is None:
        shape = [b.shape[i] if dim == 1 else dim for i, dim in enumerate(a.shape)]
    else:
        shape = list(a.shape)

    rank = len(a.shape)
    if rank > 0:
        # rng.integers(0, 0) raises ValueError, so rank-0 inputs must skip the
        # draw.  For rank > 0 the index is still drawn unconditionally so that
        # seeded runs consume the random stream exactly as before.
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any dtype other than the input's is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004440
4441 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose inputs must match exactly.

    Asserts that ``a`` and ``b`` agree in rank, dtype and every dimension,
    then registers an output with a fresh list of ``a``'s dimensions and
    ``a``'s dtype via ``ser.addOutput`` and returns that call's result.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim_a, dim_b in zip(a.shape, b.shape):
        # No broadcasting allowed: each dimension has to be identical
        assert dim_a == dim_b
        out_shape.append(dim_a)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004452
4453 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for a unary op: same shape as the input.

    The output dtype is the input dtype, except for the
    ``ErrorIf.WrongOutputType`` case where ``rng.choice`` picks a dtype
    that differs from ``a.dtype``.  Returns the result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Drop the input dtype so that the chosen output dtype is always wrong
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(a.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004471
4472 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select op (``cond`` choosing ``a`` or ``b``).

    Args:
        ser: serializer object; only ``ser.addOutput(shape, dtype)`` is called.
        rng: random generator; ``integers`` and ``choice`` are called on it.
        cond: condition tensor (only its ``shape`` is read).
        a: first data tensor (``shape`` and ``dtype`` are read).
        b: second data tensor; must have the same dtype as ``a``.
        error_name: optional ERROR_IF case.  Unless it is
            ``ErrorIf.RankMismatch`` all three tensors must share a rank.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            # Size-1 condition dimension broadcasts to the largest input dim
            shape.append(max(cond_dim, a.shape[i], b.shape[i]))
        else:
            shape.append(cond_dim)

    rank = len(a.shape)
    if rank > 0:
        # rng.integers(0, 0) raises ValueError, so rank-0 inputs must skip the
        # draw.  For rank > 0 the index is still drawn unconditionally so that
        # seeded runs consume the random stream exactly as before.
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any dtype other than the data inputs' is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004505
4506 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the boolean output tensor for a broadcasting comparison op.

    Args:
        ser: serializer object; only ``ser.addOutput(shape, dtype)`` is called.
        rng: random generator; ``integers`` and ``choice`` are called on it.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor; must have the same dtype as ``a`` and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, the same rank.
        error_name: optional ERROR_IF case used to corrupt the output shape
            (``DimensionMismatch``) or dtype (``WrongOutputType``).

    Returns:
        The result of ``ser.addOutput``; the dtype is ``DType.BOOL`` unless
        a wrong output type was requested.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    b_rank = len(b.shape)
    shape = []
    for i, dim in enumerate(a.shape):
        # Only look at b where it actually has this dimension (ranks may
        # differ in the RankMismatch case)
        if dim == 1 and b_rank > i:
            shape.append(b.shape[i])
        else:
            shape.append(dim)

    rank = len(a.shape)
    if rank > 0:
        # rng.integers(0, 0) raises ValueError, so rank-0 inputs must skip the
        # draw.  For rank > 0 the index is still drawn unconditionally so that
        # seeded runs consume the random stream exactly as before.
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # Every non-BOOL dtype is wrong for a comparison result
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004539
4540 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    The output is a copy of ``a``'s shape with the reduced axis set to 1.
    For the axis-range ERROR_IF cases the shape is left untouched, and for
    ``ShapeOfAxisNotOne`` the axis is forced to a size other than 1.
    ``WrongOutputType`` picks an output dtype different from ``a.dtype``.
    Returns the result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The axis must not end up as 1; replace it if it already is
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        # Axis is usable for indexing, so collapse it
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidates) - set([a.dtype]))
        out_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004568
4569 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for argmax along ``axis``.

    The output shape is ``a``'s shape with ``axis`` removed (the removal is
    skipped for the axis-range ERROR_IF cases) and the dtype is
    ``DType.INT32``.  The rank/shape mismatch ERROR_IF cases perturb the
    shape and ``WrongOutputType`` picks a non-INT32 dtype.  Returns the
    result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Randomly shrink or grow the rank; never shrink below rank 1
        drop_dim = rng.choice([True, False])
        if drop_dim and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Enlarge every dimension by its own random amount
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidates) - set([DType.INT32]))
        out_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004602
4603 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI and OFM is NHWC.  ``padding`` holds
    four values (two for height, then two for width); ``strides`` and
    ``dilations`` hold one value per spatial axis.  The output dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.  Returns the
    result of ``ser.addOutput``.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Number of positions the dilated kernel fits along one spatial axis
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004654
4655 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    Layouts: IFM is NDHWC, filter is ODHWI and OFM is NDHWC.  ``padding``
    holds six values (a pair per spatial axis, in D, H, W order);
    ``strides`` and ``dilations`` hold one value per spatial axis.  The
    output dtype is ``accum_dtype`` unless an ERROR_IF case overrides it.
    Returns the result of ``ser.addOutput``.
    """
    # Spatial output sizes in D, H, W order
    spatial = []
    for axis in range(3):
        padded = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        kernel_extent = (filter.shape[axis + 1] - 1) * dilations[axis]
        spatial.append((padded - kernel_extent) // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        # (1 -> depth only, 2 -> height only, 3 -> width only, 4 -> all)
        for axis in range(3):
            if change in (axis + 1, 4):
                spatial[axis] += rng.choice(choices) * strides[axis]

    d, h, w = spatial
    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
4716
4717 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a depthwise 2D convolution.

    Layouts: IFM is NHWC, filter is HWCM and OFM is NHW(C*M).  ``padding``
    holds four values (two for height, then two for width); ``strides`` and
    ``dilations`` hold one value per spatial axis.  The output dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.  Returns the
    result of ``ser.addOutput``.
    """
    # Filter is HWCM, so kernel height/width are filter dims 0 and 1
    h_span = ifm.shape[1] - 1 + padding[0] + padding[1]
    h = (h_span - (filter.shape[0] - 1) * dilations[0]) // strides[0] + 1

    w_span = ifm.shape[2] - 1 + padding[2] + padding[3]
    w = (w_span - (filter.shape[1] - 1) * dilations[1]) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    # Output channels = input channels (C) * channel multiplier (M)
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004767
4768 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling op on an NHWC input.

    ``kernel`` and ``stride`` hold (height, width) values and ``pad`` holds
    four values (two for height, then two for width).  The output dtype is
    the input dtype unless ``WrongOutputType`` is requested.  Returns the
    result of ``ser.addOutput``.
    """
    # input: NHWC
    params_usable = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_usable:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]
    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidates) - set([ifm.dtype]))
        out_dtype = rng.choice(mismatched)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004805
4806 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a fully connected op.

    Shapes: input is [N, IC], filter is [OC, IC] and the output is [N, OC].
    The output dtype is always ``accum_dtype`` (validated in arg_gen, and
    also invalidated there for ErrorIf); ``rng`` and ``error_name`` are not
    used here.  Returns the result of ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004818
4819 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a batched matrix multiply.

    Args:
        ser: serializer object; only ``ser.addOutput(shape, dtype)`` is called.
        rng: random generator; ``choice`` is called on it.
        a: left operand, shape [N, H, C].
        b: right operand, shape [N, C, W].
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: optional ERROR_IF case; ``WrongOutputType`` picks an
            output dtype that is invalid for ``a.dtype`` and
            ``WrongInputType`` falls back to ``DType.INT32``.

    Returns:
        The result of ``ser.addOutput`` for shape [N, H, W].

    Raises:
        ValueError: if ``WrongOutputType`` is requested for an input dtype
            that has no list of incorrect output dtypes.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Previously this fell through with incorrect_types unbound and
            # failed with an UnboundLocalError; report the real cause instead.
            raise ValueError(
                f"matmulOp: no incorrect output dtypes known for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004866
4867 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for concatenating the tensors ``a`` along ``axis``.

    The output shape starts as a copy of the first input's shape; the sizes
    of the remaining inputs along ``axis`` are added unless the ERROR_IF
    case makes that impossible (rank mismatch or invalid axis).
    ``ConcatShapeSumMismatch`` adds a further random amount and
    ``WrongOutputType`` picks a dtype different from the first input's.
    Returns the result of ``ser.addOutput``.
    """
    first = a[0]

    # calculate the output shape, if possible, otherwise just use the first input shape
    out_shape = first.shape.copy()
    cannot_sum = (
        # unable to concat tensors of different ranks
        ErrorIf.ConcatInputRankMismatch,
        # unable to concat tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in cannot_sum:
        for extra in a[1:]:
            out_shape[axis] += extra.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        mismatched = list(candidates - set([first.dtype]))
        out_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004902
4903 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for padding tensor ``a``.

    Each output dimension is the input dimension plus the two values in
    ``padding[i]``. ``error_name`` (an ``ErrorIf`` value or None) may
    shrink one dimension, swap in a shape from
    ``gtu.get_rank_mismatch_shape``, or force a dtype other than
    ``a.dtype``.

    Returns whatever ``ser.addOutput`` returns.
    """
    padded = a.shape.copy()
    for axis, extent in enumerate(padded):
        before, after = padding[axis][0], padding[axis][1]
        padded[axis] = before + after + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Knock 1 or 2 off a randomly chosen dimension
        victim = rng.choice(range(len(padded)))
        padded[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        padded = gtu.get_rank_mismatch_shape(rng, padded)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(padded, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the input dtype
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(padded, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004933
4934 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a DIM-style operator.

    The output always has an empty shape list. Its dtype is
    ``DType.SHAPE`` unless ``error_name`` is ``ErrorIf.WrongOutputType``,
    in which case a dtype is drawn at random from a fixed list that does
    not contain ``DType.SHAPE``. ``a`` and ``axis`` are not read here.

    Returns whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Round trip through a set so the ordering handed to rng.choice is
    # the set's iteration order
    mismatched = list(set(candidates))
    return ser.addOutput([], rng.choice(mismatched))
4954
4955 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for reshaping ``a`` to ``shape``.

    The output shape is a copy of ``shape``. For
    ``ErrorIf.TensorSizeInputOutputMismatch`` every dimension is enlarged
    by a random amount from ``rng.integers(1, 10)``; for
    ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed list
    excluding ``a.dtype``.

    Returns whatever ``ser.addOutput`` returns.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for axis, extent in enumerate(new_shape):
            new_shape[axis] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the input dtype
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(new_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004979
4980 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for slicing ``input``.

    The output shape is a copy of ``size`` and the dtype is
    ``input.dtype`` unless ``error_name`` (an ``ErrorIf`` value or None)
    asks for a deliberately wrong dtype or shape. ``start`` is not read
    here.

    Returns whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        result_dtype = input.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any candidate except the input dtype
        mismatched = list(set(candidates) - {input.dtype})
        result_dtype = rng.choice(mismatched)

    result_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for axis, extent in enumerate(result_shape):
            # Dimensions of 2 or less are only ever grown; larger ones
            # may also be shrunk
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            result_shape[axis] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        result_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        result_shape = gtu.get_rank_mismatch_shape(rng, result_shape)

    return ser.addOutput(result_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005013
5014 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for tiling ``a`` by ``multiples``.

    Each output dimension is ``a.shape[i] * multiples[i]``; ``multiples``
    must have one entry per dimension of ``a``. ``error_name`` (an
    ``ErrorIf`` value or None) may swap in a shape from
    ``gtu.get_rank_mismatch_shape`` or force a dtype other than
    ``a.dtype``.

    Returns whatever ``ser.addOutput`` returns.
    """
    tiled = a.shape.copy()
    assert len(multiples) == len(tiled)

    for axis, extent in enumerate(a.shape):
        tiled[axis] = extent * multiples[axis]

    if error_name == ErrorIf.RankMismatch:
        tiled = gtu.get_rank_mismatch_shape(rng, tiled)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(tiled, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the input dtype
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(tiled, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005042
5043 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for transposing ``a`` according to ``perms``.

    Output dimension ``i`` is ``a.shape[perms[i]]``, except for the
    ``IndexOutsideBounds`` / ``IndexUsedTwice`` errors where ``perms`` is
    not applied and the input shape is kept. Other ``error_name`` values
    may enlarge every dimension, swap in a shape from
    ``gtu.get_rank_mismatch_shape``, or force a dtype other than
    ``a.dtype``.

    Returns whatever ``ser.addOutput`` returns.
    """
    permuted = a.shape.copy()
    rank = len(permuted)

    assert len(perms) == rank

    perms_unusable = error_name in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if not perms_unusable:
        for axis in range(rank):
            permuted[axis] = a.shape[perms[axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for axis, extent in enumerate(permuted):
            permuted[axis] = extent + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(permuted, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the input dtype
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(permuted, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005075
5076 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for gathering from ``values`` via ``indices``.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The dtype is
    ``values.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``,
    in which case it is drawn from a fixed list excluding that dtype.

    Returns whatever ``ser.addOutput`` returns.
    """
    # values may only deviate from rank 3 when a wrong rank is the point
    # of the test
    assert error_name == ErrorIf.WrongRank or len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    gathered_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(gathered_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the dtype of values
    mismatched = list(set(candidates) - {values.dtype})
    return ser.addOutput(gathered_shape, rng.choice(mismatched))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005101
5102 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for scattering ``input`` into ``values_in``.

    The output reuses ``values_in.shape`` (the same object, not a copy).
    The dtype is ``values_in.dtype`` unless ``error_name`` is
    ``ErrorIf.WrongOutputType``, in which case it is drawn from a fixed
    list excluding that dtype.

    Returns whatever ``ser.addOutput`` returns.
    """
    # values_in may only deviate from rank 3 when a wrong rank is the
    # point of the test
    assert error_name == ErrorIf.WrongRank or len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    scattered_shape = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(scattered_shape, values_in.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the dtype of values_in
    mismatched = list(set(candidates) - {values_in.dtype})
    return ser.addOutput(scattered_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005130
5131 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for a table lookup on ``input``.

    The output has the input's shape. Its dtype is ``DType.INT32`` for an
    INT16 input and ``DType.INT8`` otherwise; for
    ``ErrorIf.WrongOutputType`` it is instead drawn from a fixed list
    with that correct dtype removed.

    Returns whatever ``ser.addOutput`` returns.
    """
    # The input dtype is only checked when a wrong input type is not the
    # point of the test
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Drop the correct dtype so that any pick is wrong
        candidates.remove(result_dtype)
        result_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005150
5151 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for resizing ``input``.

    ``scale`` is read as ``[scale_y_n, scale_y_d, scale_x_n, scale_x_d]``.
    Output height and width are derived from ``input.shape[1]`` and
    ``input.shape[2]`` together with ``scale``, ``offset`` and ``border``.
    ``error_name`` (an ``ErrorIf`` value or None) may clamp or skew those
    sizes and the batch / channel dimensions. ``mode`` and ``input_dtype``
    are not read here.

    Returns whatever ``serializer.addOutput`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the size arithmetic below well defined
        y_num, y_den, x_num, x_den = (
            max(part, 1) for part in (y_num, y_den, x_num, x_den)
        )

    # Calculate OH, OW
    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            largest = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, largest), min(ow, largest)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 -> change height, 2 -> change width, 3 -> change both
        change = rng.choice([1, 2, 3])
        # step in multiples of scale_y/x_d so we don't hit non-integer error case
        if change in (1, 3):
            if oh + y_den < gtu.MAX_RESIZE_DIMENSION:
                oh += y_den
            else:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
        if change in (2, 3):
            if ow + x_den < gtu.MAX_RESIZE_DIMENSION:
                ow += x_den
            else:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # The last dimension is taken from input.shape[0] here;
        # input.shape[3] is not read in this case
        channels = input.shape[0]
    else:
        channels = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)

    return serializer.addOutput([batch, oh, ow, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005229
5230 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor with the shape of ``val`` and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not read here.

    Returns whatever ``ser.addOutput`` returns.
    """
    converted = ser.addOutput(val.shape, out_dtype)
    return converted
Eric Kunzee5e26762020-10-13 16:11:07 -07005233
5234 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a transpose convolution.

    The output uses the supplied ``output_shape`` and ``accum_dtype``
    unless ``error_name`` (an ``ErrorIf`` value or None) asks otherwise.
    Note that for ``ErrorIf.ConvOutputShapeMismatch`` the caller's
    ``output_shape`` is modified in place.

    Returns whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        increments = [1, 2, 3]
        # 1 -> grow dimension 1, 2 -> grow dimension 2, 3 -> grow both
        change = rng.choice(increments)
        if change in (1, 3):
            output_shape[1] += rng.choice(increments)
        if change in (2, 3):
            output_shape[2] += rng.choice(increments)

    # Pick some potentially correct output dtype if input type is incorrect
    result_dtype = (
        DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    )

    if error_name == ErrorIf.WrongOutputType:
        # For an FP16 input both FP16 and FP32 are ruled out as "wrong"
        # choices; otherwise only the expected dtype is excluded
        if ifm.dtype == DType.FP16:
            ruled_out = [DType.FP16, DType.FP32]
        else:
            ruled_out = [result_dtype]
        mismatched = list(gtu.usableDTypes(excludes=ruled_out))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(output_shape, result_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005259
5260 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for a 2D FFT of ``ifm1`` / ``ifm2``.

    Both outputs share one shape (a copy of ``ifm1.shape``) and one dtype
    (``ifm1.dtype``) unless ``error_name`` (an ``ErrorIf`` value or None)
    asks for a wrong dtype, batch or dimension.

    Returns a list with the two results of ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype

    # The shape checks are skipped when breaking them is the point of the
    # test
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    result_shape = ifm1.shape.copy()
    result_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        mismatched = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(mismatched)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        skewed_axis = rng.choice([1, 2])
        result_shape[skewed_axis] += rng.integers(1, 10)

    # Two outputs with the same shape and dtype
    return [serializer.addOutput(result_shape, result_dtype) for _ in range(2)]
5290
5291 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for a real-input 2D FFT of ``value``.

    The output shape keeps every dimension of ``value.shape`` except the
    last, which becomes ``last // 2 + 1``. Both outputs share that shape
    and ``value.dtype`` unless ``error_name`` (an ``ErrorIf`` value or
    None) asks for a wrong dtype, batch or dimension.

    Returns a list with the two results of ``serializer.addOutput``.
    """
    in_shape = value.shape
    # The rank check is skipped when a wrong rank is the point of the test
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    result_shape = [*in_shape[:-1], in_shape[-1] // 2 + 1]
    result_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        mismatched = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(mismatched)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        skewed_axis = rng.choice([1, 2])
        result_shape[skewed_axis] += rng.integers(1, 10)

    # Two outputs with the same shape and dtype
    return [serializer.addOutput(result_shape, result_dtype) for _ in range(2)]