blob: 3180cf59057907541e729b8a847840ef5bb82e76 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Three-line C++ comment header written at the top of the generated
# *_meta_*.cpp files. The copyright year is taken from the date the generator
# is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further limits named after the specification's 8K level; presumably the
    # maximum scale, kernel size and stride that level allows - TODO confirm
    # against the TOSA specification level table.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # (test set indices 0..5 inclusive).
    TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
    # NOTE(review): presumably the minimum number of dot products a compliance
    # test should contain - verify against the users of this constant.
    TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
69 vals.append(v)
70 return tuple(sorted(vals))
71
72 self.random_float_range = {}
73 for dtype in (DType.FP32, DType.FP16, DType.BF16):
74 self.random_float_range[dtype] = convertFPRange(
75 args.tensor_fp_value_range,
76 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
77 )
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 def createSerializer(self, opName, testPath):
80 self.testPath = os.path.join(opName, testPath)
81
82 fullPath = os.path.join(self.basePath, self.testPath)
83 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010084 # Embed const data in the flatbuffer
85 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010086 if self.args.lazy_data_gen:
87 # Lazy data generation - so make constants files
88 constMode = ts.ConstMode.INPUTS
89 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 constMode = ts.ConstMode.EMBED_DUMP
91 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
93 def getSerializer(self):
94 return self.ser
95
Jeremy Johnson1271c442023-09-05 11:39:26 +010096 def serialize(self, testName, metaData=None):
97 path = Path(self.basePath) / self.testPath
98
99 # Write out TOSA flatbuffer binary
100 path_fb = path / f"{testName}.tosa"
101 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 fd.write(self.ser.serialize())
103
Jeremy Johnson1271c442023-09-05 11:39:26 +0100104 # Get JSON descriptor from serializer
105 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
106
107 if metaData:
108 # Add extra meta data to desc.json
109 desc["meta"] = metaData
110
111 # Validate desc.json before we output it
112 self.descSchemaValidator.validate_config(desc)
113
114 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100115 if "data_gen" in metaData:
116 if self.args.lazy_data_gen:
117 # Output datagen meta data as CPP data
118 path_md = path / f"{testName}_meta_data_gen.cpp"
119 with path_md.open("w") as fd:
120 fd.write(TOSA_AUTOGENERATED_HEADER)
121 fd.write("// Test meta data for data generation setup\n\n")
122 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
123 json.dump(metaData["data_gen"], fd)
124 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100125 if "compliance" in metaData:
126 # Output datagen meta data as CPP data
127 path_md = path / f"{testName}_meta_compliance.cpp"
128 with path_md.open("w") as fd:
129 fd.write(TOSA_AUTOGENERATED_HEADER)
130 fd.write("// Test meta data for compliance validation\n\n")
131 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
132 json.dump(metaData["compliance"], fd)
133 fd.write(')";\n\n')
134
135 # Write desc.json
136 path_desc = path / "desc.json"
137 with path_desc.open("w") as fd:
138 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
Matthew Haddon74567092021-07-16 15:38:20 +0100140 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000141 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100142 seed = self.random_seed + 1
143 self.rng = np.random.default_rng(seed)
144
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 def getDTypeRange(self, dtype, high_inclusive=False):
146 # Returns dtype value range boundaries (low, high)
147 # The high boundary is excluded in the range
148 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100149 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100150 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 elif dtype == DType.BOOL:
152 rng = (0, 2)
153 elif dtype == DType.UINT8:
154 rng = (0, 256)
155 elif dtype == DType.UINT16:
156 rng = (0, 65536)
157 elif dtype == DType.INT4:
158 # TOSA specific INT4 weight range from -7 to 7
159 rng = (-7, 8)
160 elif dtype == DType.INT8:
161 rng = (-128, 128)
162 elif dtype == DType.INT16:
163 rng = (-32768, 32768)
164 elif dtype in (DType.INT32, DType.SHAPE):
165 # restricting too large value for SHAPE
166 rng = (-(1 << 31), (1 << 31))
167 elif dtype == DType.INT48:
168 rng = (-(1 << 47), (1 << 47))
169 else:
170 raise Exception("Unknown dtype: {}".format(dtype))
171
172 if not high_inclusive:
173 # Exclusive high: low <= range < high
174 return rng
175 else:
176 # Inclusive range: low <= range <= high
177 return (rng[0], rng[1] - 1)
178
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 def getRandTensor(self, shape, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100180 low, high = self.getDTypeRange(dtype)
181
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700183 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100185 return np.int64(self.rng.integers(low=low, high=high, size=shape))
186 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
187 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
188
189 if dtype == DType.FP16:
190 return np.float16(f_tensor)
191 else:
192 f32_tensor = np.float32(f_tensor)
193 if dtype == DType.BF16:
194 # Floor the last 16 bits of each f32 value
195 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
196 else:
197 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700198 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 # All other integer types
200 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700201
Kevin Cheng989cb052021-04-28 16:29:44 -0700202 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 placeholders = []
204
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 assert len(shape_list) == len(dtype_list)
206
Jeremy Johnson1271c442023-09-05 11:39:26 +0100207 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700208 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100209 if not self.args.lazy_data_gen:
210 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700212
213 return placeholders
214
Kevin Cheng989cb052021-04-28 16:29:44 -0700215 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 consts = []
217
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 assert len(shape_list) == len(dtype_list)
219
Jeremy Johnson1271c442023-09-05 11:39:26 +0100220 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100222 if not self.args.lazy_data_gen:
223 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700225
226 return consts
227
228 def makeShape(self, rank):
229 if self.targetted_shape:
230 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800231 return np.int32(
232 self.rng.integers(
233 low=self.args.tensor_shape_range[0],
234 high=self.args.tensor_shape_range[1],
235 size=rank,
236 )
237 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
239 def setTargetShape(self, shape):
240 self.targetted_shape = shape
241
242 def randInt(self, low=0, high=256):
243 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
244
245 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100246 low, high = self.getDTypeRange(dtype)
247
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100248 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100249 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100250 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100251 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100252 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100253 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
254 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700255 elif dtype == DType.BOOL:
256 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700257 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 # Special size
259 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700260
261 return np.int32(self.rng.integers(low, high, size=1))[0]
262
263 def shapeStr(self, shape):
264
265 sStr = []
266 # Convert to strings
267 for i in shape:
268 sStr.append(str(i))
269
Kevin Cheng550ccc52021-03-03 11:21:43 -0800270 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100272 def typeStr(self, dtype):
273 if isinstance(dtype, list) or isinstance(dtype, tuple):
274 assert len(dtype) >= 2
275 strs = [self.typeStr(t) for t in dtype]
276 # Limit types to the first 2 as the 3rd is the accumulator
277 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700278 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100279 if dtype in gtu.DTYPE_ATTRIBUTES:
280 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700281 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100282 raise Exception(
283 "Unknown dtype, cannot convert to string: {}".format(dtype)
284 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100287 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100288 if dtype in gtu.DTYPE_ATTRIBUTES:
289 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700290 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100291 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
Luke Hutton57287132023-02-06 14:54:18 +0000293 def constrictBatchSize(self, shape):
294 # Limit the batch size unless an explicit target shape set
295 if self.args.max_batch_size and not self.args.target_shapes:
296 shape[0] = min(shape[0], self.args.max_batch_size)
297 return shape
298
James Ward30124a82023-02-02 14:56:33 +0000299 def makeDimension(self):
300 return self.randInt(
301 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
302 )
303
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100304 def tensorComplianceMetaData(
305 self, op, inputType, argsDict, outputTensor, errorName
306 ):
307 if (
308 errorName
309 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
310 or not gtu.dtypeIsSupportedByCompliance(inputType)
311 ):
312 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100313 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100314
Jeremy Johnson1271c442023-09-05 11:39:26 +0100315 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100316 compliance_tens = {
317 "mode": None,
318 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
319 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
320 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100321 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
322 mode = gtu.ComplianceMode.DOT_PRODUCT
323 compliance_tens["dot_product_info"] = {
324 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100325 "ks": int(argsDict["ksb"])
326 if "ksb" in argsDict
327 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100328 }
329 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
330 mode = gtu.ComplianceMode.FP_SPECIAL
331 elif "compliance" in op and "ulp" in op["compliance"]:
332 mode = gtu.ComplianceMode.ULP
333 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
334 elif op["op"] == Op.REDUCE_PRODUCT:
335 mode = gtu.ComplianceMode.REDUCE_PRODUCT
336 else:
337 mode = gtu.ComplianceMode.EXACT
338 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
339
340 return compliance_tens
341
342 # Build Op functions
343 # Create the output tensor (calling OutputShaper as needed)
344 # Do final tweaks to attributes (if necessary for errorIf)
345 # Add Op into graph
346 # Return resulting tensor information or BuildInfo
347
348 class BuildInfo:
349 """Enhanced build information containing result tensor and associated compliance dict."""
350
351 def __init__(self, resultTensor, complianceDict):
352 self.resultTensor = resultTensor
353 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700354
    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Add a unary operator on tensor ``a`` to the serializer.

        Args:
            op: the op entry dict, or a bare int op code (see below).
            a: input tensor.
            validator_fcns: ERROR_IF validator functions passed to
                ``TosaErrorValidator.evValidateErrorIfs``.
            error_name: ERROR_IF test being generated, or None.
            qinfo: zero points; indexed as ``[input_zp, output_zp]`` for NEGATE.

        Returns:
            The result tensor. Returns None when ``evValidateErrorIfs`` rejects
            the configuration.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # build_placeholder returns an int, ABS/other ops does not
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None)
            return result_tens
        elif op["op"] == Op.IDENTITY:
            self.ser.addOperator(op["op"], a.name, result_tens.name, None)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = [
                    TosaQuantGen.getZeroPoint(self, a.dtype),
                    TosaQuantGen.getZeroPoint(self, result_tens.dtype),
                ]

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = None
        if op["op"] == Op.NEGATE:
            # NOTE(review): assumes qinfo is a 2-element sequence here - it is
            # indexed unconditionally for NEGATE.
            attr = ts.TosaSerializerAttribute()
            attr.NegateAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
405
    def build_binary_broadcast(
        self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
    ):
        """Add a broadcasting binary operator on ``inputs`` (exactly two tensors).

        Returns:
            A ``TosaTestGen.BuildInfo`` holding the result tensor and its
            compliance meta data (always None for POW). Returns None when
            ``evValidateErrorIfs`` rejects the configuration.
        """
        assert len(inputs) == 2
        a, b = inputs
        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        if op["op"] == Op.POW:
            # TODO - add compliance support
            compliance = None
        else:
            compliance = self.tensorComplianceMetaData(
                op, a.dtype, args_dict, result_tensor, error_name
            )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700451
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700453 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000454 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return result_tens
456
    def build_arithmetic_right_shift(
        self, op, a, b, round, validator_fcns=None, error_name=None
    ):
        """Add an ARITHMETIC_RIGHT_SHIFT operator shifting ``a`` by ``b``.

        Args:
            round: value for ``ArithmeticRightShiftAttribute`` (the parameter
                name shadows the builtin ``round`` inside this method).

        Returns:
            The result tensor. Returns None when ``evValidateErrorIfs`` rejects
            the configuration.
        """
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
494
    def build_mul(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a MUL operator on ``inputs`` (exactly two tensors).

        ``args_dict["shift"]`` provides the MulAttribute shift. For non-float
        inputs the result dtype is forced to INT32. For the WrongOutputType
        error test it is then replaced by a random choice of INT8/INT16/INT48.

        Returns:
            A ``TosaTestGen.BuildInfo`` holding the result tensor and its
            compliance meta data. Returns None when ``evValidateErrorIfs``
            rejects the configuration.
        """
        assert len(inputs) == 2
        a, b = inputs
        shift = args_dict["shift"]

        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Special for multiply: Force the result to INT32 for INT types
        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
            result_tensor.setDtype(DType.INT32)

        if error_name == ErrorIf.WrongOutputType:
            # Draw a (wrong) output type; this consumes self.rng after the
            # OutputShaper call, so statement order matters for reproducibility
            all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
            outputDType = self.rng.choice(all_dtypes)
            result_tensor.setDtype(outputDType)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550
    def build_table(self, op, a, table, validator_fcns=None, error_name=None):
        """Add a TABLE operator on ``a``, with ``table`` as its TableAttribute.

        Returns:
            The result tensor. Returns None when ``evValidateErrorIfs`` rejects
            the configuration.
        """
        result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(table)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        return result_tens
584
    def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
        """Add a SELECT operator choosing between ``a`` and ``b`` using ``cond``.

        Returns:
            The result tensor. Returns None when ``evValidateErrorIfs`` rejects
            the configuration.
        """
        result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [cond.name, a.name, b.name]
        output_list = [result_tens.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=cond,
            input2=a,
            input3=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )
        return result_tens
621
    def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
        """Add a comparison operator on ``a`` and ``b``.

        The output tensor comes from ``OutputShaper.binaryComparisonOp``.

        Returns:
            The result tensor. Returns None when ``evValidateErrorIfs`` rejects
            the configuration.
        """
        result_tens = OutputShaper.binaryComparisonOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )
        return result_tens
660
    def build_argmax(
        self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
    ):
        """Add an ARGMAX operator on ``inputs`` (exactly one tensor).

        ``args_dict["axis"]`` selects the axis and is stored in an
        AxisAttribute.

        Returns:
            A ``TosaTestGen.BuildInfo`` holding the result tensor and its
            compliance meta data. Returns None when ``evValidateErrorIfs``
            rejects the configuration.
        """
        assert len(inputs) == 1
        a = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        # op["operands"] holds (placeholder count, const count)
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        # inputs[0] is the same tensor as ``a``
        compliance = self.tensorComplianceMetaData(
            op, inputs[0].dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700704
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator and return its result tensor.

    ``inputs`` must hold exactly one tensor. ``args_dict`` provides
    ``kernel``, ``stride`` and ``pad`` and, optionally, ``acc_type``. When
    ``acc_type`` is absent (max pooling has none), ``DType.UNKNOWN`` is
    serialized as the accumulator type. ``qinfo`` is a two-element list of
    zero points for the input and output, or ``None``, which serializes as
    two zeros.

    Returns ``None`` when ERROR_IF validation rejects the generated test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    kernel = args_dict["kernel"]
    stride = args_dict["stride"]
    pad = args_dict["pad"]

    out = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # A wrong-input-type test needs zero points that match the new dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    if qinfo is None:
        zp_in, zp_out = 0, 0
    else:
        zp_in, zp_out = qinfo[0], qinfo[1]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, zp_in, zp_out, accum_dtype)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out
773
def build_maxpool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MAX_POOL2D operator and attach compliance metadata.

    Operator construction is delegated to ``self.build_pool2d`` with the same
    arguments.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when ``build_pool2d`` rejects the test during
        ERROR_IF validation.
    """
    result_tensor = self.build_pool2d(
        op,
        inputs,
        args_dict,
        validator_fcns,
        error_name,
        qinfo,
    )
    if result_tensor is None:
        # build_pool2d returns None when ERROR_IF validation fails. Propagate
        # that instead of computing compliance for a missing tensor and
        # wrapping it in a BuildInfo that hides the rejection.
        return None

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100796
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and return it with compliance metadata.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and ``op["operands"]`` sums to the expected operand count.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must contain ``acc_type``, ``stride``, ``pad`` (four
            values) and ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``,
            which serializes as two zeros.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # ``weights`` avoids shadowing the ``filter`` builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700875
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator and return its result tensor.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and ``op["operands"]`` sums to the expected operand count.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must contain ``acc_type``, ``stride``, ``pad`` (six
            values) and ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``,
            which serializes as two zeros.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        test.
    """
    assert len(inputs) == 3
    # ``weights`` avoids shadowing the ``filter`` builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
949
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator and return its result tensor.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and ``op["operands"]`` sums to the expected operand count.
        ifm, filter, bias: The three input tensors. The parameter name
            ``filter`` shadows the builtin but is kept for caller
            compatibility.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        stride: Stride values.
        out_pad: Output padding (four values).
        output_shape: Requested output shape, also stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``TransposeConvAttribute``, or
            ``None``, which serializes as two zeros.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1012
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and return its result tensor.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and ``op["operands"]`` sums to the expected operand count.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Must contain ``acc_type``, ``stride``, ``pad`` and
            ``dilation``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``,
            which serializes as two zeros.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        test.
    """
    assert len(inputs) == 3
    # ``weights`` avoids shadowing the ``filter`` builtin.
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1085
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and return its result tensor.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and the two entries of ``op["operands"]`` sum to the expected
            operand count.
        ifm, filter, bias: The three input tensors. The parameter name
            ``filter`` shadows the builtin but is kept for caller
            compatibility.
        accum_dtype: Accumulator dtype forwarded to the shaper and validators.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``FullyConnectedAttribute``, or
            ``None``, which serializes as two zeros.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        test.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1134
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator and return it with compliance metadata.

    Args:
        op: Operator description; ``op["op"]`` is passed to ``addOperator``
            and the two entries of ``op["operands"]`` sum to the expected
            operand count.
        inputs: Exactly two input tensors.
        args_dict: Must contain ``acc_type``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF condition under test, or ``None``.
        qinfo: Two zero points passed to ``MatMulAttribute``, or ``None``,
            which serializes as two zeros.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # qinfo defaults to None; serialize zero zero-points instead of
        # failing with a TypeError on qinfo[0].
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001184
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator over ``args_dict["axis"]`` of one input.

    Returns ``TosaTestGen.BuildInfo`` with the result tensor and compliance
    metadata, or ``None`` when ERROR_IF validation rejects the test.
    REDUCE_PRODUCT has no compliance support yet, so its compliance entry is
    ``None``. ``qinfo`` is unused by this builder.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    result = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    # TODO: Add compliance support for REDUCE_PRODUCT!
    compliance = (
        None
        if op["op"] == Op.REDUCE_PRODUCT
        else self.tensorComplianceMetaData(
            op, src.dtype, args_dict, result, error_name
        )
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001233
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator on ``a`` with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For
    ``ErrorIf.MaxSmallerMin`` the pair is redrawn until the values differ,
    and the roles are swapped so that ``max_val < min_val``.

    Returns the result tensor, or ``None`` when ERROR_IF validation rejects
    the generated test.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)
    else:
        min_val, max_val = min(pair), max(pair)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        min_val=min_val,
        max_val=max_val,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    if a.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    # For float dtypes the bounds go in the last two ClampAttribute slots,
    # otherwise in the first two; the unused pair is zero.
    if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
        int_bounds, float_bounds = (0, 0), (min_val, max_val)
    else:
        int_bounds, float_bounds = (min_val, max_val), (0, 0)

    attr = ts.TosaSerializerAttribute()
    attr.ClampAttribute(self.ser.builder, *int_bounds, *float_bounds)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out
1289
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on ``a``.

    The attribute value is a random FP32 number. No ERROR_IF validation is
    performed here; ``validator_fcns`` is unused. Returns the result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names, out_names = [a.name], [out.name]

    attr = ts.TosaSerializerAttribute()
    attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out
1298
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator on ``a`` with no attribute.

    No ERROR_IF validation is performed here; ``validator_fcns`` is unused.
    Returns the result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1305
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator on tensor ``a``.

    Returns the result tensor, or ``None`` when ERROR_IF validation rejects
    the generated test.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    validation_kwargs = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out.shape,
        "input_dtype": a.dtype,
        "output_dtype": out.dtype,
        "result_tensors": [out],
        "input_list": names_in,
        "output_list": names_out,
        "num_operands": placeholder_count + const_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], names_in, names_out)
    return out
1336
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator on tensor ``a``.

    Returns the result tensor, or ``None`` when ERROR_IF validation rejects
    the generated test.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if passed:
        self.ser.addOperator(op["op"], input_list, output_list)
        return result
    return None
1367
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator on tensor ``a``.

    Returns the result tensor, or ``None`` when ERROR_IF validation rejects
    the generated test.
    """
    erf_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    operand_total = sum(op["operands"][:0]) + op["operands"][0] + op["operands"][1] if False else None
    del operand_total
    p_count, c_count = op["operands"]

    # ERROR_IF tests may deliberately corrupt the operand name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        num_operands=p_count + c_count,
        op=op,
        input_list=ins,
        output_list=outs,
        input_shape=a.shape,
        output_shape=erf_out.shape,
        input_dtype=a.dtype,
        output_dtype=erf_out.dtype,
        result_tensors=[erf_out],
    ):
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return erf_out
1398
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT operator joining ``inputs`` along ``args_dict["axis"]``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        inputs: input tensors; the first one supplies the input shape and
            dtype reported to the ERROR_IF validators.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    axis = args_dict["axis"]
    # The exact-int check is skipped for the WrongInputType error case.
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor.name for tensor in inputs], [out_tensor.name]
    )

    first = inputs[0]
    checks = dict(
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001446
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator for the single tensor in ``inputs``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        inputs: one-element list holding the input tensor.
        args_dict: generated arguments; reads ``"pad"`` (something with a
            ``flatten()`` method, presumably a numpy array),
            ``"pad_const_int"`` and ``"pad_const_fp"``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.
        qinfo: forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` carrying the output tensor and the result of
        ``self.tensorComplianceMetaData``, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute is built before validation and uses self.ser.builder.
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_fp
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001504
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator querying ``args_dict["axis"]`` of the single input.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        inputs: one-element list holding the input tensor.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001550
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator giving tensor ``a`` the shape ``newShape``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        a: input tensor.
        newShape: target shape, passed to ``OutputShaper.reshapeOp`` and
            ``ReshapeAttribute``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1586
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator flipping the single input along an axis.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        inputs: one-element list holding the input tensor.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when
        ``evValidateErrorIfs`` returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    # The output has the same shape/type rules as a unary op.
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001626
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator permuting the dimensions of ``a`` by ``perms``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        a: input tensor.
        perms: permutation passed to ``OutputShaper.transposeOp`` and
            ``TransposeAttribute``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is created before validation.
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1662
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator extracting ``size`` elements of ``a`` from ``start``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        a: input tensor.
        start: slice start, passed to the shaper, validators and attribute.
        size: slice size, passed to the shaper, validators and attribute.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1701
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator repeating tensor ``a`` by ``multiples``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        a: input tensor.
        multiples: repeat counts, passed to ``OutputShaper.tileOp`` and
            ``TileAttribute``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1736
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator reading from ``values`` with generated indices.

    The index tensor is created here as an INT32 const of shape
    ``(values.shape[0], W)``. ``W`` is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``, and every index lies in
    ``[0, values.shape[1])``, so no index exceeds the values tensor.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        values: tensor to gather from.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    batch = values.shape[0]
    k_extent = values.shape[1]
    lo, hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    w_extent = self.randInt(lo, hi)
    # Generator.integers excludes `high`, so indices are in [0, K).
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[batch, w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tensor
1783
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator writing ``input`` into ``values_in``.

    The index tensor is created here as an INT32 const of shape
    ``(values_in.shape[0], input.shape[1])``. Every index lies in
    ``[0, values_in.shape[1])``, so no index exceeds ``values_in``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        values_in: tensor being scattered into.
        input: tensor providing the scattered values.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    batch = values_in.shape[0]
    k_extent = values_in.shape[1]
    w_extent = input.shape[1]
    # Generator.integers excludes `high`, so indices are in [0, K).
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[batch, w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tensor
1828
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator for tensor ``input``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        input: input tensor.
        mode, scale, offset, border: resize parameters, passed to the
            shaper, the validators and ``ResizeAttribute``.
        input_dtype, output_dtype: dtypes given to the shaper and validators
            (used instead of the tensors' own dtypes).
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1890
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add an IDENTITYN operator passing ``val`` and ``val2`` through.

    Args:
        op: operator descriptor whose ``"op"`` entry is the serializer
            opcode. A bare opcode is also accepted for backward compatibility.
        val: first input tensor.
        val2: second input tensor.
        validator_fcns: accepted but not used by this builder.
        error_name: forwarded to ``OutputShaper.unaryOp``.

    Returns:
        The output tensor mirroring ``val``. The second output is registered
        with the serializer but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Every other builder hands addOperator the opcode ``op["op"]``, not the
    # whole descriptor dict. Passing the dict here was a bug.
    opcode = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1898
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register tensor ``val`` as a graph output and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are not used here. They are
    presumably kept so every ``build_*`` method shares one calling
    convention.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001902
1903 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add a CAST operator converting tensor ``val`` to ``out_dtype``.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        val: input tensor.
        out_dtype: target ``DType``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may rewrite the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1936
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a RESCALE operator converting tensor ``val`` to ``out_dtype``.

    The method generates the following:

    * random input/output zero points;
    * a random scale, one per element of the last dimension when
      ``per_channel`` is set, otherwise a single scale;
    * the matching multiplier/shift pairs from
      ``TosaQuantGen.computeMultiplierAndShift``.

    When ``scale32`` is set and no error case is requested, the input data
    stored in ``val.placeholderFilename`` is clamped so that
    ``value - input_zp`` stays within ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]``.
    The file is rewritten only if any value changed.

    Args:
        op: operator descriptor (``op["op"]`` opcode, ``op["operands"]``
            pair of operand counts summed into ``num_operands``).
        val: input tensor.
        out_dtype: output ``DType``.
        scale32: selects 32-bit scaling; otherwise scale16 limits apply.
        double_round: forwarded to the validators and ``RescaleAttribute``.
        per_channel: generate one scale per element of ``val.shape[-1]``.
        validator_fcns: ERROR_IF validators.
        error_name: the ``ErrorIf`` case being generated, or None.

    Returns:
        The output tensor, or None when ``evValidateErrorIfs`` returns a
        falsy result.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Each branch that picks a (possibly) non-zero zero point also widens the
    # type by one bit, presumably to leave headroom for the zero-point offset
    # in the scale computation below.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.zeros(nc, dtype=np.int32)
    shift_arr = np.zeros(nc, dtype=np.int32)
    min_shift_value_arr = np.zeros(nc, dtype=np.int64)
    max_shift_value_arr = np.zeros(nc, dtype=np.int64)

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int. Shifting by an np.int32 stays
        # in int32 under NumPy 2 promotion rules (NEP 50) and overflows for
        # shifts above 31.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # ERROR_IF generation may rewrite the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2091
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    By default this is a rank-0 BOOL constant holding ``cond``.

    - For the CondIfCondNotMatchingBool ERROR_IF test, the dtype is replaced
      by one chosen with ``gtu.get_wrong_output_type``.
    - For the CondIfCondShapeNotSizeOne ERROR_IF test, the shape is randomly
      ``[2]`` or ``[1, 2]``.
    """
    wants_wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if wants_wrong_type
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Valid condition: must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2108
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block outputs a random INT32 constant of that shape. For the
    CondIfOutputListThenGraphMismatch / CondIfOutputListElseGraphMismatch
    ERROR_IF tests, the matching block outputs a constant with a perturbed
    shape instead.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation rejects the
        generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Copy so the real shape is untouched. Dims above 3 may shrink by up
        # to 3; smaller dims only ever grow.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches either block's (correct) output
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single const node that is the block's output
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, mismatch_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2177
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    Block ops by input dtype:
    - FP32/BF16/FP16/INT32: THEN adds, ELSE subtracts.
    - INT8/INT16: THEN applies LOGICAL_RIGHT_SHIFT, ELSE applies
      LOGICAL_LEFT_SHIFT.

    For the CondIfInputList*/CondIfOutputList* GraphMismatch ERROR_IF tests,
    the matching block gets either a copy of ``a`` with a perturbed shape as
    its first input, or an output with that perturbed shape.

    Args:
        op: operator test-list entry. ``op["op"]`` is the COND_IF op; the
            entry is also forwarded to the ERROR_IF validators.
        a, b: the two operand tensors.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF error being tested, or None.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation rejects the
        generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, dims of 3 or less can end
        # up non-positive here (up to 3 is subtracted).
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable. Reusing ``op`` here would overwrite the
    # operator entry, and the validators below would get an Op enum value.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2260
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that accumulates ``a`` while a counter is > 0.

    The loop carries (counter, a, accumulator):
    - The counter starts at ``iter_val``.
    - The accumulator starts as zeros.
    - The COND block computes ``counter > 0``.
    - The BODY block computes ``acc + a`` and ``counter - 1``.

    For ERROR_IF tests, deliberately mismatched shapes (or a non-BOOL /
    non-size-one cond output) are used in the relevant place.

    Args:
        op: operator test-list entry (``op["op"]`` is the WHILE_LOOP op).
        a: tensor added to the accumulator each iteration.
        iter_val: initial counter value.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF error being tested, or None.

    Returns:
        The loop's accumulator output tensor, or None if ERROR_IF validation
        rejects the generated test.
    """
    # Named iter_tens (not ``iter``) to avoid shadowing the builtin
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give the copy a dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name])

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2381
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D test on the (val1, val2) input pair.

    Output tensors are created by ``OutputShaper.fft2dOp``. ``inverse`` is
    recorded in the FFT attribute.

    Returns:
        The list of result tensors, or None if ERROR_IF validation rejects the
        generated test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    pCount, cCount = op["operands"]

    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the input/output name lists may be altered
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2423
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D test on the single input ``val``.

    Output tensors are created by ``OutputShaper.rfft2dOp``.

    Returns:
        The list of result tensors, or None if ERROR_IF validation rejects the
        generated test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    pCount, cCount = op["operands"]

    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the input/output name lists may be altered
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2457
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Compute the shape/rank/dtype filters used to enumerate tests for ``op``.

    Args:
        op: operator entry. Its ``"rank"`` (min, max) pair and ``"types"``
            list are read.
        shapeFilter: requested shapes. A falsy value is treated as ``[None]``.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An operator dtype entry that
            is itself a list is kept when its first element is requested.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator used for negative tests. It is called
            as ``validator(check=False, op=op)``, and any non-None entry of
            its ``"param_reqs"`` overrides the computed filter.

    Returns:
        A dict with ``"shapeFilter"``, ``"rankFilter"`` and ``"dtypeFilter"``
        keys. Returns None for a negative test without a validator, or for
        any other ``testType``.
    """
    # Keep test sizes reasonably small: by default only ranks 1-4 inclusive
    default_test_rank_range = range(1, 5)
    shapeFilter = shapeFilter or [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Only keep requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Nothing requested: use the smaller of the operator's rank range and
        # the default range (the operator's range wins a tie)
        cleanRankFilter = min(
            range(rmin, rmax + 1), default_test_rank_range, key=len
        )
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    if testType != "negative" or validator is None:
        return None

    param_reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = param_reqs["rank"]
    dtype_req = param_reqs["dtype"]
    shape_req = param_reqs["shape"]

    return {
        # Without explicit shape requirements, keep only two shapes so the
        # number of negative tests stays small
        "shapeFilter": shapeFilter[:2] if shape_req is None else shape_req,
        "rankFilter": cleanRankFilter if rank_req is None else rank_req,
        "dtypeFilter": cleanDtypeFilter if dtype_req is None else dtype_req,
    }
2538
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to build for operator ``opName``.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict tests to. ``None``
            (the default) is equivalent to ``[None]``, meaning no shape
            restriction.
        rankFilter: optional list of ranks to restrict tests to.
        dtypeFilter: optional list of dtypes to restrict tests to.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        # Normalise here rather than using a shared mutable default list
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)
                    # Only "positive"/"negative" reach here; any other
                    # testType makes create_filter_lists return None.
                    if testType == "positive":
                        baseStr = f"{opName}_{shapeStr}_{typeStr}"
                    else:
                        baseStr = (
                            f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                        )

                    # Argument lists are (str, build-args) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2645
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build test ``testStr`` for operator ``opName`` and serialize it.

    Steps:
    1. Create a serializer.
    2. Generate the input tensors.
    3. Call the operator's build function.
    4. If the build result is truthy, write the test (with any
       data-generation / compliance meta data) via ``self.serialize``.
       Otherwise report an invalid ERROR_IF test.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name.
        dtype_or_dtypeList: one dtype applied to every operand, or a list of
            per-operand dtypes.
        error_name: ERROR_IF error being tested, or None.
        shapeList: per-operand shapes.
        testArgs: either an args dict (new interface; must contain
            ``"dg_type"``) or a sequence of extra build arguments (old
            interface).

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs, one per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Build the random tensor operands and the test
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        try:
            if error_if_validators is None:
                # Builders without ERROR_IF support take qinfo positionally
                extra_args = [] if qinfo is None else [qinfo]
                result = build_fcn(self, op, *tens, *testArgs, *extra_args)
            else:
                extra_kwargs = {} if qinfo is None else {"qinfo": qinfo}
                result = build_fcn(
                    self,
                    op,
                    *tens,
                    *testArgs,
                    validator_fcns=error_if_validators,
                    error_name=error_name,
                    **extra_kwargs,
                )
        except TypeError:
            # Usually a build function signature / argument mismatch
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002771
Eric Kunzee5e26762020-10-13 16:11:07 -07002772 def createDynamicOpLists(self):
2773
Jeremy Johnson00423432022-09-12 17:27:37 +01002774 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2775 # Already created these lists (can occur when class is initialized more than once)
2776 return
2777
Eric Kunzee5e26762020-10-13 16:11:07 -07002778 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002779 if not self.args.level8k:
2780 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2781 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2782 else:
2783 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2784 KERNELS_2D = [[1, bigK], [bigK, 2]]
2785 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002786
Kevin Cheng1533b852021-09-01 12:51:58 -07002787 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002788 testName = "conv2d_{}x{}".format(k[0], k[1])
2789 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2790 self.TOSA_OP_LIST[testName]["filter"] = k
2791 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002792
Kevin Cheng550ccc52021-03-03 11:21:43 -08002793 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2794 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2795 "depthwise_conv2d_TEMPLATE"
2796 ].copy()
2797 self.TOSA_OP_LIST[testName]["filter"] = k
2798 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002799
Kevin Cheng550ccc52021-03-03 11:21:43 -08002800 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2801 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2802 "transpose_conv2d_TEMPLATE"
2803 ].copy()
2804 self.TOSA_OP_LIST[testName]["filter"] = k
2805 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002806
Kevin Cheng1533b852021-09-01 12:51:58 -07002807 for k in KERNELS_3D:
2808 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2809 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2810 self.TOSA_OP_LIST[testName]["filter"] = k
2811 self.TOSA_OP_LIST[testName]["template"] = False
2812
Eric Kunzee5e26762020-10-13 16:11:07 -07002813 # Delete any templates after having created any dynamic ops
2814 # This is a two-pass operation because it's bad practice to delete
2815 # keys from dictionaries while iterating
2816 keyList = []
2817 for k in self.TOSA_OP_LIST:
2818 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002819 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002820 keyList.append(k)
2821 continue
2822 except KeyError:
2823 pass
2824
2825 for k in keyList:
2826 del self.TOSA_OP_LIST[k]
2827
2828 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002829 """Fill in default fields for ops if they aren't already specified.
2830 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002831 for op in self.TOSA_OP_LIST:
2832
2833 # Required fields
2834 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002835 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002836 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002837 raise Exception(
2838 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2839 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002840
2841 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002842 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002843 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002844 raise Exception(
2845 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2846 op
2847 )
2848 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002849
2850 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 _ = self.TOSA_OP_LIST[op]["types"]
2852 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002853 raise Exception(
2854 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2855 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002856
2857 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002858 _ = self.TOSA_OP_LIST[op]["op"]
2859 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002860 raise Exception(
2861 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2862 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002863
2864 # Put in default rank range, if missing
2865 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002866 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002867 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002868 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002869
2870 # Tensor operator list
2871 # 'op': op name
2872 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002873 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2874 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002875 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2876 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002877 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002878
Kevin Cheng550ccc52021-03-03 11:21:43 -08002879 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002880 TYPE_INT_FP = [
2881 DType.INT8,
2882 DType.INT16,
2883 DType.INT32,
2884 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002885 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002886 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002887 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002888
Kevin Cheng550ccc52021-03-03 11:21:43 -08002889 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002890 TYPE_FI32 = [
2891 DType.FP32,
2892 DType.FP16,
2893 DType.BF16,
2894 DType.INT32,
2895 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002896 TYPE_FIB = [
2897 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002898 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002899 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002900 DType.INT8,
2901 DType.INT16,
2902 DType.INT32,
2903 DType.BOOL,
2904 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002905 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002906
James Ward24dbc422022-10-19 12:20:31 +01002907 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002908
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002909 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002910 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002911 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002912 [DType.INT8, DType.INT8, DType.INT32],
2913 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002914 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002915 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002916 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002917 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002918 ]
2919
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002920 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002921
2922 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002923 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002924 "argmax": {
2925 "op": Op.ARGMAX,
2926 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002927 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002928 "build_fcn": (
2929 build_argmax,
2930 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002931 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002932 TosaArgGen.agAxis,
2933 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002934 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002935 "error_if_validators": (
2936 TosaErrorValidator.evAxisSmallerZero,
2937 TosaErrorValidator.evAxisLargerRank,
2938 TosaErrorValidator.evArgmaxOutputRankMismatch,
2939 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2940 TosaErrorValidator.evWrongRank,
2941 TosaErrorValidator.evWrongInputType,
2942 TosaErrorValidator.evWrongOutputType,
2943 TosaErrorValidator.evWrongInputList,
2944 TosaErrorValidator.evWrongOutputList,
2945 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002946 "data_gen": {
2947 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2948 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002949 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002950 "avg_pool2d": {
2951 "op": Op.AVG_POOL2D,
2952 "operands": (1, 0),
2953 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954 "build_fcn": (
2955 build_pool2d,
2956 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002957 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002958 TosaArgGen.agPooling,
2959 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002960 "qgen": TosaQuantGen.qgUnary,
2961 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002962 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002963 "error_if_validators": (
2964 TosaErrorValidator.evKernelSmallerOne,
2965 TosaErrorValidator.evStrideSmallerOne,
2966 TosaErrorValidator.evPadSmallerZero,
2967 TosaErrorValidator.evWrongRank,
2968 TosaErrorValidator.evWrongInputType,
2969 TosaErrorValidator.evWrongOutputType,
2970 TosaErrorValidator.evWrongInputList,
2971 TosaErrorValidator.evWrongOutputList,
2972 TosaErrorValidator.evInputZeroPointNotZero,
2973 TosaErrorValidator.evOutputZeroPointNotZero,
2974 TosaErrorValidator.evPadLargerEqualKernel,
2975 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002976 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002977 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002978 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002979 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002980 "conv2d_TEMPLATE": {
2981 "op": Op.CONV2D,
2982 "operands": (1, 2),
2983 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002984 "build_fcn": (
2985 build_conv2d,
2986 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002987 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002988 TosaArgGen.agConv,
2989 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002990 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002991 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002992 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2993 "error_if_validators": (
2994 TosaErrorValidator.evWrongInputType,
2995 TosaErrorValidator.evWrongOutputType,
2996 TosaErrorValidator.evWrongInputList,
2997 TosaErrorValidator.evWrongOutputList,
2998 TosaErrorValidator.evInputZeroPointNotZero,
2999 TosaErrorValidator.evWeightZeroPointNotZero,
3000 TosaErrorValidator.evPadSmallerZero,
3001 TosaErrorValidator.evStrideSmallerOne,
3002 TosaErrorValidator.evDilationSmallerOne,
3003 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003004 TosaErrorValidator.evConvOutputShapeMismatch,
3005 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003006 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003007 "data_gen": {
3008 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3009 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003010 "template": True,
3011 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003012 # Templated operator. Filled in by createDynamicOpLists
3013 "conv3d_TEMPLATE": {
3014 "op": Op.CONV3D,
3015 "operands": (1, 2),
3016 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003017 "build_fcn": (
3018 build_conv3d,
3019 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003020 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003021 TosaArgGen.agConv,
3022 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003023 "qgen": TosaQuantGen.qgConv,
3024 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003025 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3026 "error_if_validators": (
3027 TosaErrorValidator.evWrongInputType,
3028 TosaErrorValidator.evWrongOutputType,
3029 TosaErrorValidator.evWrongInputList,
3030 TosaErrorValidator.evWrongOutputList,
3031 TosaErrorValidator.evInputZeroPointNotZero,
3032 TosaErrorValidator.evWeightZeroPointNotZero,
3033 TosaErrorValidator.evPadSmallerZero,
3034 TosaErrorValidator.evStrideSmallerOne,
3035 TosaErrorValidator.evDilationSmallerOne,
3036 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003037 TosaErrorValidator.evConvOutputShapeMismatch,
3038 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003039 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003040 "template": True,
3041 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003042 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003043 "depthwise_conv2d_TEMPLATE": {
3044 "op": Op.DEPTHWISE_CONV2D,
3045 "operands": (1, 2),
3046 "filter": [1, 1],
3047 "rank": (4, 4),
3048 "build_fcn": (
3049 build_depthwise_conv2d,
3050 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003051 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003052 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003053 ),
3054 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003055 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003056 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3057 "error_if_validators": (
3058 TosaErrorValidator.evWrongInputType,
3059 TosaErrorValidator.evWrongOutputType,
3060 TosaErrorValidator.evWrongInputList,
3061 TosaErrorValidator.evWrongOutputList,
3062 TosaErrorValidator.evInputZeroPointNotZero,
3063 TosaErrorValidator.evWeightZeroPointNotZero,
3064 TosaErrorValidator.evPadSmallerZero,
3065 TosaErrorValidator.evStrideSmallerOne,
3066 TosaErrorValidator.evDilationSmallerOne,
3067 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003068 TosaErrorValidator.evConvOutputShapeMismatch,
3069 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003070 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003071 "template": True,
3072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003073 "fully_connected": {
3074 "op": Op.FULLY_CONNECTED,
3075 "operands": (1, 2),
3076 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 "build_fcn": (
3078 build_fully_connected,
3079 TosaTensorGen.tgFullyConnected,
3080 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01003081 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003084 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003085 "error_if_validators": (
3086 TosaErrorValidator.evInputZeroPointNotZero,
3087 TosaErrorValidator.evWeightZeroPointNotZero,
3088 TosaErrorValidator.evWrongRank,
3089 TosaErrorValidator.evWrongInputType,
3090 TosaErrorValidator.evWrongOutputType,
3091 TosaErrorValidator.evWrongInputList,
3092 TosaErrorValidator.evWrongOutputList,
3093 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "matmul": {
3096 "op": Op.MATMUL,
3097 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003098 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099 "build_fcn": (
3100 build_matmul,
3101 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003102 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003103 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "qgen": TosaQuantGen.qgMatmul,
3106 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003107 "error_if_validators": (
3108 TosaErrorValidator.evInputZeroPointNotZero,
3109 TosaErrorValidator.evWrongRank,
3110 TosaErrorValidator.evWrongInputType,
3111 TosaErrorValidator.evWrongOutputType,
3112 TosaErrorValidator.evWrongInputList,
3113 TosaErrorValidator.evWrongOutputList,
3114 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003115 "data_gen": {
3116 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003119 "max_pool2d": {
3120 "op": Op.MAX_POOL2D,
3121 "operands": (1, 0),
3122 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003123 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003124 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003126 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003127 TosaArgGen.agPooling,
3128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003130 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003131 "error_if_validators": (
3132 TosaErrorValidator.evKernelSmallerOne,
3133 TosaErrorValidator.evStrideSmallerOne,
3134 TosaErrorValidator.evPadSmallerZero,
3135 TosaErrorValidator.evWrongRank,
3136 TosaErrorValidator.evWrongInputType,
3137 TosaErrorValidator.evWrongOutputType,
3138 TosaErrorValidator.evWrongInputList,
3139 TosaErrorValidator.evWrongOutputList,
3140 TosaErrorValidator.evPadLargerEqualKernel,
3141 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003142 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003143 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003144 "data_gen": {
3145 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003148 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003149 "transpose_conv2d_TEMPLATE": {
3150 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003151 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003152 "rank": (4, 4),
3153 "build_fcn": (
3154 build_transpose_conv2d,
3155 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003156 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003157 TosaArgGen.agTransposeConv2D,
3158 ),
3159 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003160 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003161 "invalid_test_validators": (
3162 TosaInvalidValidator.ivHeightWidthInvalid,
3163 TosaInvalidValidator.ivNonPositiveOutputShape,
3164 ),
3165 "error_if_validators": (
3166 TosaErrorValidator.evWrongInputType,
3167 TosaErrorValidator.evWrongOutputType,
3168 TosaErrorValidator.evWrongInputList,
3169 TosaErrorValidator.evWrongOutputList,
3170 TosaErrorValidator.evInputZeroPointNotZero,
3171 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003172 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003173 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003174 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003175 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003176 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003177 "template": True,
3178 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003179 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003180 "clamp": {
3181 "op": Op.CLAMP,
3182 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183 "build_fcn": (
3184 build_clamp,
3185 TosaTensorGen.tgBasic,
3186 TosaTensorValuesGen.tvgDefault,
3187 None,
3188 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003189 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003190 "error_if_validators": (
3191 TosaErrorValidator.evMaxSmallerMin,
3192 TosaErrorValidator.evWrongInputType,
3193 TosaErrorValidator.evWrongOutputType,
3194 TosaErrorValidator.evWrongInputList,
3195 TosaErrorValidator.evWrongOutputList,
3196 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003197 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003198 "sigmoid": {
3199 "op": Op.SIGMOID,
3200 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003201 "build_fcn": (
3202 build_sigmoid,
3203 TosaTensorGen.tgBasic,
3204 TosaTensorValuesGen.tvgDefault,
3205 None,
3206 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003207 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 "error_if_validators": (
3209 TosaErrorValidator.evWrongInputType,
3210 TosaErrorValidator.evWrongOutputType,
3211 TosaErrorValidator.evWrongInputList,
3212 TosaErrorValidator.evWrongOutputList,
3213 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003214 },
3215 "tanh": {
3216 "op": Op.TANH,
3217 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 "build_fcn": (
3219 build_tanh,
3220 TosaTensorGen.tgBasic,
3221 TosaTensorValuesGen.tvgDefault,
3222 None,
3223 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003224 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003225 "error_if_validators": (
3226 TosaErrorValidator.evWrongInputType,
3227 TosaErrorValidator.evWrongOutputType,
3228 TosaErrorValidator.evWrongInputList,
3229 TosaErrorValidator.evWrongOutputList,
3230 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003231 },
Won Jeon78155c62023-06-10 00:20:04 +00003232 "erf": {
3233 "op": Op.ERF,
3234 "operands": (1, 0),
3235 "build_fcn": (
3236 build_erf,
3237 TosaTensorGen.tgBasic,
3238 TosaTensorValuesGen.tvgDefault,
3239 None,
3240 ),
3241 "types": TYPE_FP,
3242 "error_if_validators": (
3243 TosaErrorValidator.evWrongInputType,
3244 TosaErrorValidator.evWrongOutputType,
3245 TosaErrorValidator.evWrongInputList,
3246 TosaErrorValidator.evWrongOutputList,
3247 ),
3248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003249 # Elementwise Binary Operators
3250 "add": {
3251 "op": Op.ADD,
3252 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253 "build_fcn": (
3254 build_binary_broadcast,
3255 TosaTensorGen.tgBroadcastFuzz,
3256 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003257 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003258 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003260 "error_if_validators": (
3261 TosaErrorValidator.evRankMismatch,
3262 TosaErrorValidator.evWrongInputType,
3263 TosaErrorValidator.evWrongOutputType,
3264 TosaErrorValidator.evWrongInputList,
3265 TosaErrorValidator.evWrongOutputList,
3266 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003267 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003268 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003269 "data_gen": {
3270 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3271 },
3272 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 "arithmetic_right_shift": {
3275 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3276 "operands": (2, 0),
3277 "build_fcn": (
3278 build_arithmetic_right_shift,
3279 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003280 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 TosaArgGen.agArithmeticRightShift,
3282 ),
3283 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003284 "error_if_validators": (
3285 TosaErrorValidator.evRankMismatch,
3286 TosaErrorValidator.evWrongInputType,
3287 TosaErrorValidator.evWrongOutputType,
3288 TosaErrorValidator.evWrongInputList,
3289 TosaErrorValidator.evWrongOutputList,
3290 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003291 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003292 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003294 "bitwise_and": {
3295 "op": Op.BITWISE_AND,
3296 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003297 "build_fcn": (
3298 build_binary_broadcast,
3299 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003300 TosaTensorValuesGen.tvgLazyGenDefault,
3301 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003302 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003304 "error_if_validators": (
3305 TosaErrorValidator.evRankMismatch,
3306 TosaErrorValidator.evWrongInputType,
3307 TosaErrorValidator.evWrongOutputType,
3308 TosaErrorValidator.evWrongInputList,
3309 TosaErrorValidator.evWrongOutputList,
3310 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003311 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003312 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003314 "bitwise_or": {
3315 "op": Op.BITWISE_OR,
3316 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003317 "build_fcn": (
3318 build_binary_broadcast,
3319 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003320 TosaTensorValuesGen.tvgLazyGenDefault,
3321 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003322 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003323 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003324 "error_if_validators": (
3325 TosaErrorValidator.evRankMismatch,
3326 TosaErrorValidator.evWrongInputType,
3327 TosaErrorValidator.evWrongOutputType,
3328 TosaErrorValidator.evWrongInputList,
3329 TosaErrorValidator.evWrongOutputList,
3330 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003331 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 "bitwise_xor": {
3335 "op": Op.BITWISE_XOR,
3336 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 "build_fcn": (
3338 build_binary_broadcast,
3339 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003340 TosaTensorValuesGen.tvgLazyGenDefault,
3341 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 "error_if_validators": (
3345 TosaErrorValidator.evRankMismatch,
3346 TosaErrorValidator.evWrongInputType,
3347 TosaErrorValidator.evWrongOutputType,
3348 TosaErrorValidator.evWrongInputList,
3349 TosaErrorValidator.evWrongOutputList,
3350 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003351 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003354 "intdiv": {
3355 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003356 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003357 "build_fcn": (
3358 build_binary_broadcast,
3359 TosaTensorGen.tgBroadcastFuzz,
3360 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003361 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003362 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003363 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003364 "error_if_validators": (
3365 TosaErrorValidator.evRankMismatch,
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongInputList,
3369 TosaErrorValidator.evWrongOutputList,
3370 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003371 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003372 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 "logical_and": {
3375 "op": Op.LOGICAL_AND,
3376 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003377 "build_fcn": (
3378 build_binary_broadcast,
3379 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003380 TosaTensorValuesGen.tvgLazyGenDefault,
3381 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003382 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003384 "error_if_validators": (
3385 TosaErrorValidator.evRankMismatch,
3386 TosaErrorValidator.evWrongInputType,
3387 TosaErrorValidator.evWrongOutputType,
3388 TosaErrorValidator.evWrongInputList,
3389 TosaErrorValidator.evWrongOutputList,
3390 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003391 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "logical_left_shift": {
3395 "op": Op.LOGICAL_LEFT_SHIFT,
3396 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003397 "build_fcn": (
3398 build_binary_broadcast,
3399 TosaTensorGen.tgBroadcastFuzz,
3400 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003401 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003404 "error_if_validators": (
3405 TosaErrorValidator.evRankMismatch,
3406 TosaErrorValidator.evWrongInputType,
3407 TosaErrorValidator.evWrongOutputType,
3408 TosaErrorValidator.evWrongInputList,
3409 TosaErrorValidator.evWrongOutputList,
3410 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003411 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 "logical_right_shift": {
3415 "op": Op.LOGICAL_RIGHT_SHIFT,
3416 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003417 "build_fcn": (
3418 build_binary_broadcast,
3419 TosaTensorGen.tgBroadcastFuzz,
3420 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003421 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003424 "error_if_validators": (
3425 TosaErrorValidator.evRankMismatch,
3426 TosaErrorValidator.evWrongInputType,
3427 TosaErrorValidator.evWrongOutputType,
3428 TosaErrorValidator.evWrongInputList,
3429 TosaErrorValidator.evWrongOutputList,
3430 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003431 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 "logical_or": {
3435 "op": Op.LOGICAL_OR,
3436 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 "build_fcn": (
3438 build_binary_broadcast,
3439 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003440 TosaTensorValuesGen.tvgLazyGenDefault,
3441 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 "error_if_validators": (
3445 TosaErrorValidator.evRankMismatch,
3446 TosaErrorValidator.evWrongInputType,
3447 TosaErrorValidator.evWrongOutputType,
3448 TosaErrorValidator.evWrongInputList,
3449 TosaErrorValidator.evWrongOutputList,
3450 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003451 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 "logical_xor": {
3455 "op": Op.LOGICAL_XOR,
3456 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003457 "build_fcn": (
3458 build_binary_broadcast,
3459 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003460 TosaTensorValuesGen.tvgLazyGenDefault,
3461 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 "error_if_validators": (
3465 TosaErrorValidator.evRankMismatch,
3466 TosaErrorValidator.evWrongInputType,
3467 TosaErrorValidator.evWrongOutputType,
3468 TosaErrorValidator.evWrongInputList,
3469 TosaErrorValidator.evWrongOutputList,
3470 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003471 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 "maximum": {
3475 "op": Op.MAXIMUM,
3476 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003477 "build_fcn": (
3478 build_binary_broadcast,
3479 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003480 TosaTensorValuesGen.tvgLazyGenDefault,
3481 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003484 "error_if_validators": (
3485 TosaErrorValidator.evRankMismatch,
3486 TosaErrorValidator.evWrongInputType,
3487 TosaErrorValidator.evWrongOutputType,
3488 TosaErrorValidator.evWrongInputList,
3489 TosaErrorValidator.evWrongOutputList,
3490 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003491 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003492 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003493 "data_gen": {
3494 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003497 "minimum": {
3498 "op": Op.MINIMUM,
3499 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500 "build_fcn": (
3501 build_binary_broadcast,
3502 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003503 TosaTensorValuesGen.tvgLazyGenDefault,
3504 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003505 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003507 "error_if_validators": (
3508 TosaErrorValidator.evRankMismatch,
3509 TosaErrorValidator.evWrongInputType,
3510 TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongInputList,
3512 TosaErrorValidator.evWrongOutputList,
3513 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003514 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003515 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003516 "data_gen": {
3517 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003519 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 "mul": {
3521 "op": Op.MUL,
3522 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 "build_fcn": (
3524 build_mul,
3525 TosaTensorGen.tgBroadcastFuzz,
3526 TosaTensorValuesGen.tvgMul,
3527 TosaArgGen.agMul,
3528 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 "error_if_validators": (
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 TosaErrorValidator.evRankMismatch,
3536 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003537 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003539 "data_gen": {
3540 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3541 },
3542 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 "pow": {
3545 "op": Op.POW,
3546 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 "build_fcn": (
3548 build_binary_broadcast,
3549 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003550 TosaTensorValuesGen.tvgLazyGenDefault,
3551 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 "error_if_validators": (
3555 TosaErrorValidator.evRankMismatch,
3556 TosaErrorValidator.evWrongInputType,
3557 TosaErrorValidator.evWrongOutputType,
3558 TosaErrorValidator.evWrongInputList,
3559 TosaErrorValidator.evWrongOutputList,
3560 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003561 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003562 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "sub": {
3565 "op": Op.SUB,
3566 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 "build_fcn": (
3568 build_binary_broadcast,
3569 TosaTensorGen.tgBroadcastFuzz,
3570 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003571 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003572 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 "error_if_validators": (
3575 TosaErrorValidator.evRankMismatch,
3576 TosaErrorValidator.evWrongInputType,
3577 TosaErrorValidator.evWrongOutputType,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003581 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003583 "data_gen": {
3584 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3585 },
3586 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003587 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003588 "table": {
3589 "op": Op.TABLE,
3590 # Use the automatic generation functions to create the input array
3591 # but create the table tensor in the build function, as it may be
3592 # a different type from the input
3593 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003594 "build_fcn": (
3595 build_table,
3596 TosaTensorGen.tgBasic,
3597 TosaTensorValuesGen.tvgDefault,
3598 TosaArgGen.agTable,
3599 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003600 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 "error_if_validators": (
3602 TosaErrorValidator.evWrongInputType,
3603 TosaErrorValidator.evWrongOutputType,
3604 TosaErrorValidator.evWrongInputList,
3605 TosaErrorValidator.evWrongOutputList,
3606 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003608 # Elementwise Unary operators
3609 "abs": {
3610 "op": Op.ABS,
3611 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 "build_fcn": (
3613 build_unary,
3614 TosaTensorGen.tgBasic,
3615 TosaTensorValuesGen.tvgDefault,
3616 None,
3617 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evWrongInputType,
3621 TosaErrorValidator.evWrongOutputType,
3622 TosaErrorValidator.evWrongInputList,
3623 TosaErrorValidator.evWrongOutputList,
3624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 "bitwise_not": {
3627 "op": Op.BITWISE_NOT,
3628 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003629 "build_fcn": (
3630 build_unary,
3631 TosaTensorGen.tgBasic,
3632 TosaTensorValuesGen.tvgDefault,
3633 None,
3634 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003636 "error_if_validators": (
3637 TosaErrorValidator.evWrongInputType,
3638 TosaErrorValidator.evWrongOutputType,
3639 TosaErrorValidator.evWrongInputList,
3640 TosaErrorValidator.evWrongOutputList,
3641 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 "ceil": {
3644 "op": Op.CEIL,
3645 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003646 "build_fcn": (
3647 build_unary,
3648 TosaTensorGen.tgBasic,
3649 TosaTensorValuesGen.tvgDefault,
3650 None,
3651 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003653 "error_if_validators": (
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongInputList,
3657 TosaErrorValidator.evWrongOutputList,
3658 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003659 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 "clz": {
3661 "op": Op.CLZ,
3662 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 "build_fcn": (
3664 build_unary,
3665 TosaTensorGen.tgBasic,
3666 TosaTensorValuesGen.tvgDefault,
3667 None,
3668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003670 "error_if_validators": (
3671 TosaErrorValidator.evWrongInputType,
3672 TosaErrorValidator.evWrongOutputType,
3673 TosaErrorValidator.evWrongInputList,
3674 TosaErrorValidator.evWrongOutputList,
3675 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 "exp": {
3678 "op": Op.EXP,
3679 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 "build_fcn": (
3681 build_unary,
3682 TosaTensorGen.tgBasic,
3683 TosaTensorValuesGen.tvgDefault,
3684 None,
3685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 "error_if_validators": (
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 "floor": {
3695 "op": Op.FLOOR,
3696 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 "build_fcn": (
3698 build_unary,
3699 TosaTensorGen.tgBasic,
3700 TosaTensorValuesGen.tvgDefault,
3701 None,
3702 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003704 "error_if_validators": (
3705 TosaErrorValidator.evWrongInputType,
3706 TosaErrorValidator.evWrongOutputType,
3707 TosaErrorValidator.evWrongInputList,
3708 TosaErrorValidator.evWrongOutputList,
3709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "log": {
3712 "op": Op.LOG,
3713 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_unary,
3716 TosaTensorGen.tgBasic,
3717 TosaTensorValuesGen.tvgDefault,
3718 None,
3719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongInputList,
3725 TosaErrorValidator.evWrongOutputList,
3726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 "logical_not": {
3729 "op": Op.LOGICAL_NOT,
3730 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003731 "build_fcn": (
3732 build_unary,
3733 TosaTensorGen.tgBasic,
3734 TosaTensorValuesGen.tvgDefault,
3735 None,
3736 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003738 "error_if_validators": (
3739 TosaErrorValidator.evWrongInputType,
3740 TosaErrorValidator.evWrongOutputType,
3741 TosaErrorValidator.evWrongInputList,
3742 TosaErrorValidator.evWrongOutputList,
3743 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 "negate": {
3746 "op": Op.NEGATE,
3747 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003748 "build_fcn": (
3749 build_unary,
3750 TosaTensorGen.tgBasic,
3751 TosaTensorValuesGen.tvgNegate,
3752 None,
3753 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 "qgen": TosaQuantGen.qgUnary,
3755 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 "error_if_validators": (
3757 TosaErrorValidator.evInputZeroPointNotZero,
3758 TosaErrorValidator.evOutputZeroPointNotZero,
3759 TosaErrorValidator.evWrongInputType,
3760 TosaErrorValidator.evWrongOutputType,
3761 TosaErrorValidator.evWrongInputList,
3762 TosaErrorValidator.evWrongOutputList,
3763 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 "reciprocal": {
3766 "op": Op.RECIPROCAL,
3767 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003768 "build_fcn": (
3769 build_unary,
3770 TosaTensorGen.tgBasic,
3771 TosaTensorValuesGen.tvgDefault,
3772 None,
3773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003775 "error_if_validators": (
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
3780 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "rsqrt": {
3783 "op": Op.RSQRT,
3784 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_unary,
3787 TosaTensorGen.tgBasic,
3788 TosaTensorValuesGen.tvgDefault,
3789 None,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 # Elementwise Ternary operators
3800 "select": {
3801 "op": Op.SELECT,
3802 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003803 "build_fcn": (
3804 build_select,
3805 TosaTensorGen.tgBroadcastFuzz,
3806 TosaTensorValuesGen.tvgSelect,
3807 None,
3808 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 "error_if_validators": (
3811 TosaErrorValidator.evRankMismatch,
3812 TosaErrorValidator.evWrongInputType,
3813 TosaErrorValidator.evWrongOutputType,
3814 TosaErrorValidator.evWrongInputList,
3815 TosaErrorValidator.evWrongOutputList,
3816 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003817 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003818 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 # Comparison operators
3821 "equal": {
3822 "op": Op.EQUAL,
3823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 "build_fcn": (
3825 build_comparison,
3826 TosaTensorGen.tgBroadcastFuzz,
3827 TosaTensorValuesGen.tvgEqual,
3828 None,
3829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003838 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "greater_equal": {
3842 "op": Op.GREATER_EQUAL,
3843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_comparison,
3846 TosaTensorGen.tgBroadcastFuzz,
3847 TosaTensorValuesGen.tvgDefault,
3848 None,
3849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evRankMismatch,
3853 TosaErrorValidator.evWrongInputType,
3854 TosaErrorValidator.evWrongOutputType,
3855 TosaErrorValidator.evWrongInputList,
3856 TosaErrorValidator.evWrongOutputList,
3857 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003858 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "greater": {
3862 "op": Op.GREATER,
3863 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003864 "build_fcn": (
3865 build_comparison,
3866 TosaTensorGen.tgBroadcastFuzz,
3867 TosaTensorValuesGen.tvgDefault,
3868 None,
3869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 "error_if_validators": (
3872 TosaErrorValidator.evRankMismatch,
3873 TosaErrorValidator.evWrongInputType,
3874 TosaErrorValidator.evWrongOutputType,
3875 TosaErrorValidator.evWrongInputList,
3876 TosaErrorValidator.evWrongOutputList,
3877 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003878 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 # Reduction operators
3882 "reduce_all": {
3883 "op": Op.REDUCE_ALL,
3884 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003885 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003886 "build_fcn": (
3887 build_reduce,
3888 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003889 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003890 TosaArgGen.agAxis,
3891 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003892 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 "error_if_validators": (
3894 TosaErrorValidator.evAxisLargerRank,
3895 TosaErrorValidator.evAxisSmallerZero,
3896 TosaErrorValidator.evShapeOfAxisNotOne,
3897 TosaErrorValidator.evWrongInputType,
3898 TosaErrorValidator.evWrongOutputType,
3899 TosaErrorValidator.evWrongRank,
3900 TosaErrorValidator.evWrongInputList,
3901 TosaErrorValidator.evWrongOutputList,
3902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003904 "reduce_any": {
3905 "op": Op.REDUCE_ANY,
3906 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003907 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003908 "build_fcn": (
3909 build_reduce,
3910 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003911 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003912 TosaArgGen.agAxis,
3913 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003915 "error_if_validators": (
3916 TosaErrorValidator.evAxisLargerRank,
3917 TosaErrorValidator.evAxisSmallerZero,
3918 TosaErrorValidator.evShapeOfAxisNotOne,
3919 TosaErrorValidator.evWrongInputType,
3920 TosaErrorValidator.evWrongOutputType,
3921 TosaErrorValidator.evWrongRank,
3922 TosaErrorValidator.evWrongInputList,
3923 TosaErrorValidator.evWrongOutputList,
3924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "reduce_max": {
3927 "op": Op.REDUCE_MAX,
3928 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003929 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003930 "build_fcn": (
3931 build_reduce,
3932 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003933 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 TosaArgGen.agAxis,
3935 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003937 "error_if_validators": (
3938 TosaErrorValidator.evAxisLargerRank,
3939 TosaErrorValidator.evAxisSmallerZero,
3940 TosaErrorValidator.evShapeOfAxisNotOne,
3941 TosaErrorValidator.evWrongInputType,
3942 TosaErrorValidator.evWrongOutputType,
3943 TosaErrorValidator.evWrongRank,
3944 TosaErrorValidator.evWrongInputList,
3945 TosaErrorValidator.evWrongOutputList,
3946 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003947 "data_gen": {
3948 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3949 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003950 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003951 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003952 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003954 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 "build_fcn": (
3956 build_reduce,
3957 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003958 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003959 TosaArgGen.agAxis,
3960 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 "error_if_validators": (
3963 TosaErrorValidator.evAxisLargerRank,
3964 TosaErrorValidator.evAxisSmallerZero,
3965 TosaErrorValidator.evShapeOfAxisNotOne,
3966 TosaErrorValidator.evWrongInputType,
3967 TosaErrorValidator.evWrongOutputType,
3968 TosaErrorValidator.evWrongRank,
3969 TosaErrorValidator.evWrongInputList,
3970 TosaErrorValidator.evWrongOutputList,
3971 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003972 "data_gen": {
3973 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "reduce_product": {
3977 "op": Op.REDUCE_PRODUCT,
3978 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003979 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_reduce,
3982 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003983 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003984 TosaArgGen.agAxis,
3985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evAxisLargerRank,
3989 TosaErrorValidator.evAxisSmallerZero,
3990 TosaErrorValidator.evShapeOfAxisNotOne,
3991 TosaErrorValidator.evWrongInputType,
3992 TosaErrorValidator.evWrongOutputType,
3993 TosaErrorValidator.evWrongRank,
3994 TosaErrorValidator.evWrongInputList,
3995 TosaErrorValidator.evWrongOutputList,
3996 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003997 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003998 "reduce_sum": {
3999 "op": Op.REDUCE_SUM,
4000 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004001 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004002 "build_fcn": (
4003 build_reduce,
4004 TosaTensorGen.tgBasic,
4005 TosaTensorValuesGen.tvgReduceSum,
4006 TosaArgGen.agAxis,
4007 ),
James Ward24dbc422022-10-19 12:20:31 +01004008 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 "error_if_validators": (
4010 TosaErrorValidator.evAxisLargerRank,
4011 TosaErrorValidator.evAxisSmallerZero,
4012 TosaErrorValidator.evShapeOfAxisNotOne,
4013 TosaErrorValidator.evWrongInputType,
4014 TosaErrorValidator.evWrongOutputType,
4015 TosaErrorValidator.evWrongRank,
4016 TosaErrorValidator.evWrongInputList,
4017 TosaErrorValidator.evWrongOutputList,
4018 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004019 "data_gen": {
4020 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004023 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004024 "concat": {
4025 "op": Op.CONCAT,
4026 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004027 "build_fcn": (
4028 build_concat,
4029 TosaTensorGen.tgConcat,
4030 TosaTensorValuesGen.tvgConcat,
4031 TosaArgGen.agAxis,
4032 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004033 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004034 "error_if_validators": (
4035 TosaErrorValidator.evAxisLargerRank,
4036 TosaErrorValidator.evAxisSmallerZero,
4037 TosaErrorValidator.evConcatInputRankMismatch,
4038 TosaErrorValidator.evConcatShapeSumMismatch,
4039 TosaErrorValidator.evConcatInputDimMismatch,
4040 TosaErrorValidator.evWrongInputType,
4041 TosaErrorValidator.evWrongOutputType,
4042 TosaErrorValidator.evWrongOutputList,
4043 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004044 },
4045 "pad": {
4046 "op": Op.PAD,
4047 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004048 "build_fcn": (
4049 build_pad,
4050 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004051 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004052 TosaArgGen.agPad,
4053 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004054 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004055 "error_if_validators": (
4056 TosaErrorValidator.evWrongInputType,
4057 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004058 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004059 TosaErrorValidator.evWrongOutputType,
4060 TosaErrorValidator.evWrongInputList,
4061 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004062 TosaErrorValidator.evRankMismatch,
4063 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004064 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004065 "data_gen": {
4066 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4067 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004068 },
Won Jeona21b2e82023-08-10 10:33:01 +00004069 "dim": {
4070 "op": Op.DIM,
4071 "operands": (1, 0),
4072 "build_fcn": (
4073 build_dim,
4074 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004075 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004076 TosaArgGen.agAxis,
4077 ),
4078 "types": TYPE_FIB,
4079 "error_if_validators": (
4080 TosaErrorValidator.evAxisLargerRank,
4081 TosaErrorValidator.evAxisSmallerZero,
4082 TosaErrorValidator.evWrongInputType,
4083 TosaErrorValidator.evWrongInputList,
4084 TosaErrorValidator.evWrongOutputList,
4085 TosaErrorValidator.evWrongRank,
4086 ),
4087 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004088 "reshape": {
4089 "op": Op.RESHAPE,
4090 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 "build_fcn": (
4092 build_reshape,
4093 TosaTensorGen.tgBasic,
4094 TosaTensorValuesGen.tvgDefault,
4095 TosaArgGen.agReshape,
4096 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004097 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004098 "error_if_validators": (
4099 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4100 TosaErrorValidator.evWrongInputType,
4101 TosaErrorValidator.evWrongOutputType,
4102 TosaErrorValidator.evWrongInputList,
4103 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004104 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4105 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004107 },
4108 "reverse": {
4109 "op": Op.REVERSE,
4110 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 "build_fcn": (
4112 build_reverse,
4113 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004114 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004115 TosaArgGen.agAxis,
4116 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004117 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 "error_if_validators": (
4119 TosaErrorValidator.evAxisSmallerZero,
4120 TosaErrorValidator.evAxisLargerRank,
4121 TosaErrorValidator.evWrongInputType,
4122 TosaErrorValidator.evWrongOutputType,
4123 TosaErrorValidator.evWrongInputList,
4124 TosaErrorValidator.evWrongOutputList,
4125 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004126 },
4127 "slice": {
4128 "op": Op.SLICE,
4129 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004130 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 "build_fcn": (
4132 build_slice,
4133 TosaTensorGen.tgBasic,
4134 TosaTensorValuesGen.tvgDefault,
4135 TosaArgGen.agSlice,
4136 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004137 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 "error_if_validators": (
4139 TosaErrorValidator.evStartSmallerZero,
4140 TosaErrorValidator.evSizeSmallerEqualZero,
4141 TosaErrorValidator.evStartSizeOutsideBounds,
4142 TosaErrorValidator.evSizeOutputShapeMismatch,
4143 TosaErrorValidator.evInputSizeStartLengthMismatch,
4144 TosaErrorValidator.evWrongRank,
4145 TosaErrorValidator.evWrongInputType,
4146 TosaErrorValidator.evWrongOutputType,
4147 TosaErrorValidator.evWrongInputList,
4148 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004149 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004150 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004151 },
4152 "tile": {
4153 "op": Op.TILE,
4154 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004155 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 "build_fcn": (
4157 build_tile,
4158 TosaTensorGen.tgBasic,
4159 TosaTensorValuesGen.tvgDefault,
4160 TosaArgGen.agTile,
4161 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004162 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004163 "error_if_validators": (
4164 TosaErrorValidator.evWrongInputType,
4165 TosaErrorValidator.evWrongOutputType,
4166 TosaErrorValidator.evWrongInputList,
4167 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004168 TosaErrorValidator.evRankMismatch,
4169 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004170 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004171 },
4172 "transpose": {
4173 "op": Op.TRANSPOSE,
4174 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004175 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004176 "build_fcn": (
4177 build_transpose,
4178 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004179 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004180 TosaArgGen.agTranspose,
4181 ),
4182 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004183 "error_if_validators": (
4184 TosaErrorValidator.evIndexOutsideBounds,
4185 TosaErrorValidator.evIndexUsedTwice,
4186 TosaErrorValidator.evWrongInputType,
4187 TosaErrorValidator.evWrongOutputType,
4188 TosaErrorValidator.evWrongInputList,
4189 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004190 TosaErrorValidator.evWrongRank,
4191 TosaErrorValidator.evRankMismatch,
4192 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004193 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004195 # Data nodes
4196 "const": {
4197 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004198 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004199 "build_fcn": (
4200 build_const,
4201 TosaTensorGen.tgBasic,
4202 TosaTensorValuesGen.tvgDefault,
4203 None,
4204 ),
Luke Hutton65872422023-02-20 10:33:04 +00004205 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004207 "identity": {
4208 "op": Op.IDENTITY,
4209 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004210 "build_fcn": (
4211 build_unary,
4212 TosaTensorGen.tgBasic,
4213 TosaTensorValuesGen.tvgDefault,
4214 None,
4215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004216 "types": TYPE_FIB,
4217 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004218 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004219 "gather": {
4220 "op": Op.GATHER,
4221 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4222 "operands": (1, 0),
4223 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004224 "build_fcn": (
4225 build_gather,
4226 TosaTensorGen.tgBasic,
4227 TosaTensorValuesGen.tvgDefault,
4228 None,
4229 ),
James Ward24dbc422022-10-19 12:20:31 +01004230 "types": (
4231 DType.INT8,
4232 DType.INT16,
4233 DType.INT32,
4234 DType.FP16,
4235 DType.BF16,
4236 DType.FP32,
4237 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004238 "error_if_validators": (
4239 TosaErrorValidator.evWrongInputType,
4240 TosaErrorValidator.evWrongOutputType,
4241 TosaErrorValidator.evWrongInputList,
4242 TosaErrorValidator.evWrongOutputList,
4243 TosaErrorValidator.evWrongRank,
4244 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004245 },
4246 "scatter": {
4247 "op": Op.SCATTER,
4248 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004250 "operands": (2, 0),
4251 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004252 "build_fcn": (
4253 build_scatter,
4254 TosaTensorGen.tgScatter,
4255 TosaTensorValuesGen.tvgDefault,
4256 None,
4257 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004258 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004259 "error_if_validators": (
4260 TosaErrorValidator.evWrongInputType,
4261 TosaErrorValidator.evWrongOutputType,
4262 TosaErrorValidator.evWrongInputList,
4263 TosaErrorValidator.evWrongOutputList,
4264 TosaErrorValidator.evWrongRank,
4265 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004266 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004267 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004268 "resize": {
4269 "op": Op.RESIZE,
4270 "operands": (1, 0),
4271 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004272 "build_fcn": (
4273 build_resize,
4274 TosaTensorGen.tgNHWC,
4275 TosaTensorValuesGen.tvgDefault,
4276 TosaArgGen.agResize,
4277 ),
James Ward24dbc422022-10-19 12:20:31 +01004278 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 "invalid_test_validators": (
4280 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004281 ),
4282 "error_if_validators": (
4283 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004284 TosaErrorValidator.evScaleSmallerEqualZero,
4285 TosaErrorValidator.evScaleNLargerMax,
4286 TosaErrorValidator.evScaleDLargerMax,
4287 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004288 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004289 TosaErrorValidator.evBorderSmallerMin,
4290 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004291 TosaErrorValidator.evWrongInputType,
4292 TosaErrorValidator.evWrongOutputType,
4293 TosaErrorValidator.evWrongRank,
4294 TosaErrorValidator.evWrongInputList,
4295 TosaErrorValidator.evWrongOutputList,
4296 TosaErrorValidator.evBatchMismatch,
4297 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004298 TosaErrorValidator.evResizeOutputShapeMismatch,
4299 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004300 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004301 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004302 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004303 "cast": {
4304 "op": Op.CAST,
4305 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 "build_fcn": (
4307 build_cast,
4308 TosaTensorGen.tgBasic,
4309 TosaTensorValuesGen.tvgDefault,
4310 TosaArgGen.agCast,
4311 ),
James Ward8b390432022-08-12 20:48:56 +01004312 "types": (
4313 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004314 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004315 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004316 DType.INT8,
4317 DType.INT16,
4318 DType.INT32,
4319 DType.BOOL,
4320 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004321 "error_if_validators": (
4322 TosaErrorValidator.evWrongInputType,
4323 TosaErrorValidator.evWrongOutputType,
4324 TosaErrorValidator.evWrongInputList,
4325 TosaErrorValidator.evWrongOutputList,
4326 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004327 },
4328 "rescale": {
4329 "op": Op.RESCALE,
4330 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004331 "build_fcn": (
4332 build_rescale,
4333 TosaTensorGen.tgBasic,
4334 TosaTensorValuesGen.tvgDefault,
4335 TosaArgGen.agRescale,
4336 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004337 "types": [
4338 DType.UINT8,
4339 DType.INT8,
4340 DType.INT16,
4341 DType.INT32,
4342 DType.INT48,
4343 DType.UINT16,
4344 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004345 "error_if_validators": (
4346 TosaErrorValidator.evInputZeroPointNotZero,
4347 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004348 TosaErrorValidator.evU16InputZeroPointNotValid,
4349 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004350 TosaErrorValidator.evScaleTrue,
4351 TosaErrorValidator.evScaleNotTrue,
4352 TosaErrorValidator.evWrongInputType,
4353 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 TosaErrorValidator.evWrongInputList,
4355 TosaErrorValidator.evWrongOutputList,
4356 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004357 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004358 # Custom
4359 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004361 # Two variants of cond_if, one that generates one of two constant tensors (no
4362 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4363 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004364 "cond_if_const": {
4365 "op": Op.COND_IF,
4366 "operands": (0, 2),
4367 "build_fcn": (
4368 build_cond_if_const,
4369 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004370 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004371 TosaArgGen.agCondIf,
4372 ),
4373 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 "error_if_validators": (
4375 TosaErrorValidator.evOutputListThenGraphMismatch,
4376 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004377 TosaErrorValidator.evCondIfCondNotMatchingBool,
4378 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004379 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004380 },
4381 "cond_if_binary": {
4382 "op": Op.COND_IF,
4383 "operands": (2, 0),
4384 "build_fcn": (
4385 build_cond_if_binary,
4386 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004387 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004388 TosaArgGen.agCondIf,
4389 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004390 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004391 "error_if_validators": (
4392 TosaErrorValidator.evInputListThenGraphMismatch,
4393 TosaErrorValidator.evInputListElseGraphMismatch,
4394 TosaErrorValidator.evOutputListThenGraphMismatch,
4395 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004396 TosaErrorValidator.evCondIfCondNotMatchingBool,
4397 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004398 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004399 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004400 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004401 "while_loop": {
4402 "op": Op.WHILE_LOOP,
4403 "operands": (0, 1),
4404 "build_fcn": (
4405 build_while_loop,
4406 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004407 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004408 TosaArgGen.agWhileLoop,
4409 ),
4410 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004411 "error_if_validators": (
4412 TosaErrorValidator.evInputListOutputListMismatch,
4413 TosaErrorValidator.evInputListCondGraphMismatch,
4414 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4415 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4416 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004417 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004418 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004419 },
Luke Hutton57287132023-02-06 14:54:18 +00004420 "fft2d": {
4421 "op": Op.FFT2D,
4422 "operands": (2, 0),
4423 "rank": (3, 3),
4424 "build_fcn": (
4425 build_fft2d,
4426 TosaTensorGen.tgFFT2d,
4427 TosaTensorValuesGen.tvgDefault,
4428 TosaArgGen.agFFT2d,
4429 ),
4430 "types": [DType.FP32],
4431 "error_if_validators": (
4432 TosaErrorValidator.evWrongInputType,
4433 TosaErrorValidator.evWrongOutputType,
4434 TosaErrorValidator.evWrongInputList,
4435 TosaErrorValidator.evWrongOutputList,
4436 TosaErrorValidator.evWrongRank,
4437 TosaErrorValidator.evBatchMismatch,
4438 TosaErrorValidator.evKernelNotPowerOfTwo,
4439 TosaErrorValidator.evFFTInputShapeMismatch,
4440 TosaErrorValidator.evFFTOutputShapeMismatch,
4441 ),
4442 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004443 "rfft2d": {
4444 "op": Op.RFFT2D,
4445 "operands": (1, 0),
4446 "rank": (3, 3),
4447 "build_fcn": (
4448 build_rfft2d,
4449 TosaTensorGen.tgRFFT2d,
4450 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004451 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004452 ),
4453 "types": [DType.FP32],
4454 "error_if_validators": (
4455 TosaErrorValidator.evWrongInputType,
4456 TosaErrorValidator.evWrongOutputType,
4457 TosaErrorValidator.evWrongInputList,
4458 TosaErrorValidator.evWrongOutputList,
4459 TosaErrorValidator.evWrongRank,
4460 TosaErrorValidator.evBatchMismatch,
4461 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004462 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004463 ),
4464 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004465 }
4466
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467
Eric Kunzee5e26762020-10-13 16:11:07 -07004468class OutputShaper:
4469 # Methods in this class compute the expected output shape and datatype
4470 # for common classes of operations
def __init__(self):
    """Nothing to initialise: OutputShaper keeps no instance state."""
4473
4474 # These methods return arguments that can be used for
4475 # creating a new output tensor
4476 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an element-wise binary op that broadcasts.

    Each output dimension copies ``a``; a size-1 dimension of ``a`` takes the
    matching dimension of ``b``, but only when ``error_name`` is None.

    ERROR_IF variants:
      * ``ErrorIf.RankMismatch``: the equal-rank assertion is skipped.
      * ``ErrorIf.DimensionMismatch``: one random output dimension grows by 1.
      * ``ErrorIf.WrongOutputType``: a random dtype other than ``a.dtype``
        (from a fixed candidate set) is used.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    can_broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and can_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Always drawn, even when unused, so the rng advances identically
    # whatever error_name is.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    }
    wrong_dtypes = list(candidate_dtypes - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004509
4510 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose operands must match exactly.

    Asserts that ``a`` and ``b`` agree in rank, dtype and every dimension.
    The output uses a fresh copy of ``a``'s shape and ``a``'s dtype.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim

    out_shape = list(a.shape)
    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004521
4522 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an element-wise unary op.

    The output keeps ``a.shape``. Its dtype is ``a.dtype``, except for the
    ``ErrorIf.WrongOutputType`` case, where a random dtype different from
    ``a.dtype`` (from a fixed candidate set) is chosen.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(candidate_dtypes - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004540
4541 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a select(cond, a, b) op.

    Output dimensions follow ``cond``. Where ``cond`` has size 1 (and
    ``error_name`` is None), the largest of the three operand sizes is used.

    ERROR_IF variants:
      * ``ErrorIf.RankMismatch``: the equal-rank assertion is skipped.
      * ``ErrorIf.DimensionMismatch``: one random output dimension grows by 1.
      * ``ErrorIf.WrongOutputType``: a random dtype other than ``a.dtype``
        (from a fixed candidate set) is used.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if error_name is None and cond_dim == 1:
            out_shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_dim)

    # Drawn even when unused, keeping the rng sequence independent of error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004574
4575 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an element-wise comparison op.

    Output dimensions follow ``a``. A size-1 dimension of ``a`` takes the
    matching dimension of ``b`` wherever ``b`` has that dimension. The dtype
    is ``DType.BOOL``.

    ERROR_IF variants:
      * ``ErrorIf.RankMismatch``: the equal-rank assertion is skipped.
      * ``ErrorIf.DimensionMismatch``: one random output dimension grows by 1.
      * ``ErrorIf.WrongOutputType``: a random non-BOOL dtype from a fixed
        list is used instead.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank guard matters only when the rank assertion above
    # was skipped and b may have fewer dimensions than a.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn even when unused, keeping the rng sequence independent of error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004608
4609 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction of ``a`` along ``axis``.

    Normally the output is ``a``'s shape with ``axis`` collapsed to 1.

    ERROR_IF variants:
      * ``ErrorIf.AxisSmallerZero`` / ``ErrorIf.AxisLargerRank``: the shape is
        left untouched (``axis`` cannot be used as an index).
      * ``ErrorIf.ShapeOfAxisNotOne``: the axis is not collapsed; if it is
        already 1 it is replaced by a random size in [2, 10).
      * ``ErrorIf.WrongOutputType``: a random dtype other than ``a.dtype``
        (from a fixed candidate set) is used.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    leave_axis_dim = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if not leave_axis_dim:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004637
4638 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of an argmax of ``a`` along ``axis``.

    Normally the output is ``a``'s shape with ``axis`` removed.

    ERROR_IF variants:
      * ``ErrorIf.AxisSmallerZero`` / ``ErrorIf.AxisLargerRank``: no
        dimension is removed.
      * ``ErrorIf.ArgmaxOutputRankMismatch``: the rank is perturbed by
        randomly dropping the leading dimension (if more than one remains)
        or appending a trailing 1.
      * ``ErrorIf.ArgmaxOutputShapeMismatch``: every dimension grows by a
        random amount in [1, 10).
      * ``ErrorIf.WrongOutputType``: a random non-INT32 dtype is used.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004671
4672 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a 2D convolution.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. ``padding[0:2]`` pads H and
    ``padding[2:4]`` pads W. The output dtype is ``accum_dtype``.

    ERROR_IF variants:
      * ``ErrorIf.ConvOutputShapeMismatch``: H and/or W are grown by whole
        multiples of their stride.
      * ``ErrorIf.WrongInputType``: the output dtype falls back to INT32.
      * ``ErrorIf.WrongOutputType``: a random usable dtype outside the
        excluded set is chosen.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-kernel output size for one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole multiples of the stride so the non-integer
        # output-size error case is not hit as well.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a wrong input type, fall back to some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably because both are
        # valid accumulator dtypes for it - confirm against the spec).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004723
4724 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a 3D convolution.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. ``padding`` holds a
    (before, after) pair per spatial dimension in D, H, W order. The output
    dtype is ``accum_dtype``.

    ERROR_IF variants:
      * ``ErrorIf.ConvOutputShapeMismatch``: one or all of D/H/W are grown by
        whole multiples of their stride.
      * ``ErrorIf.WrongInputType``: the output dtype falls back to INT32.
      * ``ErrorIf.WrongOutputType``: a random usable dtype outside the
        excluded set is chosen.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-kernel output size for one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    # Spatial output sizes in D, H, W order.
    spatial = [
        _out_size(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i + 1],
            dilations[i],
            strides[i],
        )
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 grows only D/H/W respectively; 4 grows all three.
        # Grow by whole multiples of the stride so the non-integer
        # output-size error case is not hit as well.
        for i in range(3):
            if change in (i + 1, 4):
                spatial[i] += rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With a wrong input type, fall back to some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably because both are
        # valid accumulator dtypes for it - confirm against the spec).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4785
4786 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a depthwise 2D convolution.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). ``padding[0:2]`` pads H and
    ``padding[2:4]`` pads W. The output dtype is ``accum_dtype``.

    ERROR_IF variants:
      * ``ErrorIf.ConvOutputShapeMismatch``: H and/or W are grown by whole
        multiples of their stride.
      * ``ErrorIf.WrongInputType``: the output dtype falls back to INT32.
      * ``ErrorIf.WrongOutputType``: a random usable dtype outside the
        excluded set is chosen.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """

    def _out_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Dilated-kernel output size for one spatial dimension.
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    kernel_h, kernel_w, in_channels, multiplier = (
        filter.shape[0],
        filter.shape[1],
        filter.shape[2],
        filter.shape[3],
    )

    h = _out_size(ifm.shape[1], padding[0], padding[1], kernel_h, dilations[0], strides[0])
    w = _out_size(ifm.shape[2], padding[2], padding[3], kernel_w, dilations[1], strides[1])

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole multiples of the stride so the non-integer
        # output-size error case is not hit as well.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, in_channels * multiplier]

    # With a wrong input type, fall back to some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably because both are
        # valid accumulator dtypes for it - confirm against the spec).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004836
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling op to the serializer.

    The input is NHWC. Each spatial output size is
    ``(in + pad_before + pad_after - kernel) // stride + 1``.

    Args:
        ser: serializer that receives the new output tensor.
        rng: numpy random generator used for ERROR_IF perturbations.
        ifm: input tensor; its ``shape`` and ``dtype`` are read.
        kernel: ``[kernel_y, kernel_x]``.
        stride: ``[stride_y, stride_x]``.
        pad: ``[top, bottom, left, right]``.
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    has_valid_geometry = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if has_valid_geometry:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # A bad stride/padding makes the test invalid anyway, so a 1x1
        # spatial output is good enough.
        out_h, out_w = 1, 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        axes_to_change = rng.choice(steps)  # 1: H only, 2: W only, 3: both
        # Grow by whole multiples of the stride so the separate
        # non-integer-output error case is not triggered.
        if axes_to_change in (1, 3):
            out_h = out_h + rng.choice(steps) * stride[0]
        if axes_to_change in (2, 3):
            out_w = out_w + rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004874
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor of FULLY_CONNECTED to the serializer.

    Shapes: ``input`` is [N, IC], ``filter`` is [OC, IC], output is [N, OC].
    The output dtype is ``accum_dtype``, which argument generation has
    already validated (or deliberately invalidated for ERROR_IF tests).
    ``rng`` and ``error_name`` are not read here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004887
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor of MATMUL to the serializer.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W], output is [N, H, W].

    Args:
        ser: serializer that receives the new output tensor.
        rng: numpy random generator used for ERROR_IF perturbations.
        a: first input tensor; ``shape`` and ``dtype`` are read.
        b: second input tensor; only ``shape`` is read.
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Input dtypes not listed above used to leave ``incorrect_types``
            # unbound (UnboundLocalError). Fall back to every usable dtype
            # other than the expected accumulator type.
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004935
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor of CONCAT to the serializer.

    The output starts from the first input's shape, and its ``axis``
    dimension is summed across all inputs. For ERROR_IF cases where that
    sum cannot be formed (rank mismatch or an invalid axis), the first
    input's shape is used unchanged.

    Args:
        ser: serializer that receives the new output tensor.
        rng: numpy random generator used for ERROR_IF perturbations.
        axis: concatenation axis.
        inputs: list of input tensors; the first supplies shape and dtype.
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004971
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of PAD to the serializer.

    Output dimension ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    For PadOutputShapeMismatch, one random dimension is shrunk by 1 or 2.
    For RankMismatch, the rank is altered.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    for dim_idx, in_dim in enumerate(a.shape):
        pad_before = padding[dim_idx][0]
        pad_after = padding[dim_idx][1]
        out_shape[dim_idx] = pad_before + pad_after + in_dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005002
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of DIM to the serializer.

    The output is created with an empty shape list and dtype
    ``DType.SHAPE``. For WrongOutputType, a random non-SHAPE dtype is
    chosen instead. ``a`` and ``axis`` are not read here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(non_shape_dtypes)))
    return ser.addOutput([], bad_dtype)
5023
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of RESHAPE to the serializer.

    The output takes the requested ``shape``. For
    TensorSizeInputOutputMismatch, every dimension is enlarged by a random
    value in [1, 9] to provoke the size mismatch.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    new_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005048
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor of SLICE to the serializer.

    The output shape is ``size``, with these ERROR_IF variations:
    SizeOutputShapeMismatch perturbs every dimension,
    InputSizeStartLengthMismatch uses the input shape, and RankMismatch
    alters the rank. ``start`` is not read here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(size):
            # Dimensions of 2 or less are only grown, never shrunk.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005082
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor of TILE to the serializer.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``. For
    RankMismatch, the rank is altered afterwards.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, in_dim in enumerate(a.shape):
        out_shape[idx] = in_dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005111
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of TRANSPOSE to the serializer.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For the
    IndexOutsideBounds and IndexUsedTwice cases the permutation is not
    applied and the input shape is kept.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    perms_are_usable = error_name not in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if perms_are_usable:
        for dst_axis in range(len(out_shape)):
            out_shape[dst_axis] = a.shape[perms[dst_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst_axis in range(len(out_shape)):
            out_shape[dst_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005144
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor of GATHER to the serializer.

    Shapes: ``values`` [N, K, C] and ``indices`` [N, W] give output
    [N, W, C]. The rank/batch asserts are skipped for the WrongRank case.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput([batch, width, channels], out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005170
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor of SCATTER to the serializer.

    ``values_in`` and ``input`` are rank 3 and ``indices`` is rank 2, with
    N, W and C agreeing as asserted below. The asserts are skipped for the
    WrongRank case. The output has the shape of ``values_in``.

    Args:
        ser: serializer that receives the new output tensor.
        rng: numpy random generator used for ERROR_IF perturbations.
        values_in: tensor supplying the output shape and dtype.
        indices: index tensor (shape checked only).
        input: tensor of values to scatter (shape checked only).
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy so the output tensor does not share (alias) the input tensor's
    # shape list; an in-place edit to one would otherwise change the other.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005199
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor of TABLE to the serializer.

    The output keeps the input shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise; for WrongOutputType a random different dtype is
    used instead.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    expected_dtype = DType.INT8
    if input.dtype == DType.INT16:
        expected_dtype = DType.INT32

    out_dtype = expected_dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(
            [dt for dt in candidate_dtypes if dt != expected_dtype]
        )

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005219
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor of RESIZE to the serializer.

    ``scale`` is ``[y_n, y_d, x_n, x_d]``; ``offset`` and ``border`` are
    ``[y, x]``. Each output spatial size is
    ``((in - 1) * n - offset + border) // d + 1``. ``mode`` and
    ``input_dtype`` are not read here.

    Returns:
        Whatever ``serializer.addOutput`` returns for the new tensor.
    """
    y_num, y_den = scale[0], scale[1]
    x_num, x_den = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp the deliberately bad scale so the shape math stays usable.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Changed scale/offset/border values in ERROR_IF tests can produce
        # unusable sizes; keep the output tensor itself valid.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_cap = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, dim_cap), min(out_w, dim_cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(dim, step):
            # Move by one denominator step so the non-integer error case is
            # not hit; step down if stepping up would reach the limit.
            if dim + step >= gtu.MAX_RESIZE_DIMENSION:
                dim -= step
                assert dim > 0  # Should have been caught in agResize
                return dim
            return dim + step

        which = rng.choice([1, 2, 3])  # 1: H only, 2: W only, 3: both
        if which in (1, 3):
            out_h = _shift(out_h, y_den)
        if which in (2, 3):
            out_w = _shift(out_w, x_den)

    if error_name == ErrorIf.WrongRank:
        # Batch size is reused as the last dim (input.shape[3] is not read).
        n_dim, c_dim = input.shape[0], input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        n_dim, c_dim = input.shape[0] + rng.integers(1, 10), input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        n_dim, c_dim = input.shape[0], input.shape[3] + rng.integers(1, 10)
    else:
        n_dim, c_dim = input.shape[0], input.shape[3]

    return serializer.addOutput([n_dim, out_h, out_w, c_dim], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005298
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor of a type-conversion op to the serializer.

    The output keeps ``val``'s shape and uses the requested ``out_dtype``.
    ``rng`` and ``error_name`` are accepted only for a uniform interface.
    """
    input_shape = val.shape
    return ser.addOutput(input_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005302
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor of TRANSPOSE_CONV2D to the serializer.

    The supplied ``output_shape`` is used directly. NOTE: for
    ConvOutputShapeMismatch, entries 1 and/or 2 of the caller's
    ``output_shape`` list are modified in place.

    Args:
        ser: serializer that receives the new output tensor.
        rng: numpy random generator used for ERROR_IF perturbations.
        ifm: input tensor; only ``dtype`` is read.
        output_shape: output shape list (may be mutated, see above).
        accum_dtype: output dtype for valid tests.
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: H only, 2: W only, 3: both
        if which in (1, 3):
            output_shape[1] = output_shape[1] + rng.choice(steps)
        if which in (2, 3):
            output_shape[2] = output_shape[2] + rng.choice(steps)

    # With a wrong input type, fall back to a potentially correct INT32.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005328
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors of FFT2D to the serializer.

    Both outputs take ``ifm1``'s shape and dtype, apart from these ERROR_IF
    perturbations: WrongOutputType picks a non-FP32 dtype, BatchMismatch
    grows dim 0, and FFTOutputShapeMismatch grows dim 1 or 2.

    Returns:
        A two-element list of the tensors returned by
        ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grown_dim = rng.choice([1, 2])
        out_shape[grown_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5359
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors of RFFT2D to the serializer.

    Each output has the input shape with the last dimension replaced by
    ``W // 2 + 1`` and the input's dtype. ERROR_IF perturbations:
    WrongOutputType picks a non-FP32 dtype, BatchMismatch grows dim 0, and
    FFTOutputShapeMismatch grows dim 1 or 2.

    Returns:
        A two-element list of the tensors returned by
        ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grown_dim = rng.choice([1, 2])
        out_shape[grown_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]