blob: d1fe11d936b726a78e4ad05d028271c625600104 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file; the
# copyright year is taken from the date the generator runs.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further 8K level limits (scale, kernel and stride maxima)
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range: test sets 0..5
TOSA_MI_DOT_PRODUCT_TEST_SETS = range(6)
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator from the parsed command line arguments.

    Sets up the seeded random number generator, the op lists, the quantization
    generator, the desc.json schema validator, the data generator library and
    the per-dtype floating point value ranges used by getDTypeRange.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Replaced for each test by createSerializer
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # JSON schema validation of desc.json
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def resolveFPRange(range_spec, fp_max):
        # Swap the "max"/"-max" keywords for +/- the dtype's largest value,
        # then return the bounds in ascending order
        resolved = (
            fp_max if v == "max" else -fp_max if v == "-max" else v
            for v in range_spec
        )
        return tuple(sorted(resolved))

    # Work out floating point range for each floating point dtype
    fp_dtypes = (DType.FP32, DType.FP16, DType.BF16)
    self.random_float_range = {
        dtype: resolveFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
        )
        for dtype in fp_dtypes
    }
78
def createSerializer(self, opName, testPath):
    """Start a new test: make its directory and a fresh serializer for it.

    The test lives in <basePath>/<opName>/<testPath>; the relative part is
    remembered in self.testPath for serialize().
    """
    self.testPath = os.path.join(opName, testPath)
    test_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(test_dir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(test_dir, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    current = self.ser
    return current
95
def serialize(self, testName, metaData=None):
    """Write the current test's files into its test directory.

    The directory is Path(self.basePath) / self.testPath. Files written:
      * <testName>.tosa - the serialized TOSA flatbuffer
      * <testName>_meta_data_gen.cpp - metaData["data_gen"] as a C++ string,
        only when present and lazy data generation is enabled
      * <testName>_meta_compliance.cpp - metaData["compliance"] as a C++
        string, when present
      * desc.json - the serializer's JSON test descriptor, with metaData
        (if given) stored under its "meta" key

    The descriptor is passed to the schema validator (validate_config) after
    the flatbuffer is written and before any other file is written.

    Args:
        testName: base file name for the test outputs.
        metaData: optional dict of extra meta data; may hold "data_gen"
            and/or "compliance" entries.
    """
    path = Path(self.basePath) / self.testPath

    def writeCppMetaData(fileName, description, varPrefix, data):
        # Embed JSON meta data in a C++ raw string literal:
        #   const char* <varPrefix>_<dir name> = R"(<json>)";
        # NOTE: a ')"' sequence inside the JSON text would end the raw
        # string early, as the default raw string delimiter is used.
        with (path / fileName).open("w", encoding="utf-8") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// {description}\n\n")
            fd.write(f'const char* {varPrefix}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    with (path / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output data generation meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_data_gen.cpp",
                "Test meta data for data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppMetaData(
                f"{testName}_meta_compliance.cpp",
                "Test meta data for compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    with (path / "desc.json").open("w", encoding="utf-8") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Replace self.rng with a new generator.

    Uses seed when given, otherwise the generator's initial seed plus one.
    """
    effective_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(effective_seed)
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) bounds used for random values of dtype.

    For integer-like types high is exclusive (low <= value < high) unless
    high_inclusive is True, in which case (low, high - 1) is returned.
    Floating point types return their self.random_float_range entry
    unchanged, whatever high_inclusive is.

    Raises:
        Exception: if dtype is not a recognised type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # (matching dtypes, exclusive range) pairs, checked in order
    int_ranges = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # restricting too large value for SHAPE
        ((DType.INT32, DType.SHAPE), (-(1 << 31), 1 << 31)),
        ((DType.INT48,), (-(1 << 47), 1 << 47)),
    )
    for matches, (low, high) in int_ranges:
        if dtype in matches:
            return (low, high - 1) if high_inclusive else (low, high)

    raise Exception("Unknown dtype: {}".format(dtype))
178
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of the given shape for dtype.

    Values come from self.rng within getDTypeRange(dtype). Storage types:
    BOOL -> np.bool_, INT48 -> np.int64, FP16 -> np.float16,
    FP32/BF16 -> np.float32, every other type -> np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)
    samples = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        samples = np.float32(gtu.vect_f32_to_bf16(samples))
    return samples
Eric Kunzee5e26762020-10-13 16:11:07 -0700201
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair and return them in order.

    With lazy data generation no data is attached (None is passed);
    otherwise each placeholder gets fresh random data from getRandTensor.
    """
    assert len(shape_list) == len(dtype_list)

    lazy = self.args.lazy_data_gen
    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if lazy else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
214
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair and return them in order.

    With lazy data generation no data is attached (None is passed);
    otherwise each constant gets fresh random data from getRandTensor.
    """
    assert len(shape_list) == len(dtype_list)

    lazy = self.args.lazy_data_gen
    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if lazy else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))
    return tensors
227
def makeShape(self, rank):
    """Return an np.int32 shape array.

    If a target shape has been set (and is truthy) it is returned as-is;
    otherwise rank random dimensions are drawn from args.tensor_shape_range
    (upper bound exclusive).
    NOTE(review): the truthiness test would raise for a multi-element numpy
    array target shape - presumably targets are plain lists; confirm.
    """
    forced = self.targetted_shape
    if not forced:
        low = self.args.tensor_shape_range[0]
        high = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=low, high=high, size=rank))
    return np.int32(forced)
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
def setTargetShape(self, shape):
    """Make makeShape return shape (pass None/empty to go back to random)."""
    self.targetted_shape = shape
241
def randInt(self, low=0, high=256):
    """Return one np.int32 drawn from self.rng with low <= value < high."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
244
def getRandNumberDType(self, dtype):
    """Return one random scalar of dtype, drawn within getDTypeRange(dtype).

    FP32 -> np.float32, FP16 -> np.float16, BF16 -> result of
    gtu.vect_f32_to_bf16 on an np.float32, BOOL -> a choice of False/True,
    INT48 -> np.int64, any other type -> np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample = np.float32(sample)
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample)
        return sample

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 values need 64-bit storage (special size)
    storage = np.int64 if dtype == DType.INT48 else np.int32
    return storage(self.rng.integers(low, high, size=1))[0]
262
def shapeStr(self, shape):
    """Return the dimensions of shape joined by "x", e.g. [1, 2, 3] -> "1x2x3".

    An empty shape gives the empty string.
    """
    # Build the string in one pass instead of a manual append loop
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271
def typeStr(self, dtype):
    """Return the short name for dtype, or "AxB" for a list/tuple of dtypes.

    For a list/tuple (at least 2 entries) every entry is converted, but only
    the first two names are joined, as the 3rd is the accumulator.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
def typeWidth(self, dtype):
    """Return the "width" entry of gtu.DTYPE_ATTRIBUTES for dtype.

    Presumably a bit width - confirm against gtu.DTYPE_ATTRIBUTES.

    Raises:
        Exception: if dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
def constrictBatchSize(self, shape):
    """Clamp shape[0] to args.max_batch_size, in place, and return shape.

    Nothing changes when max_batch_size is unset/0 or explicit target
    shapes were requested.
    """
    limit = self.args.max_batch_size
    if limit and not self.args.target_shapes:
        shape[0] = min(shape[0], limit)
    return shape
298
def makeDimension(self):
    """Return one random dimension size from args.tensor_shape_range."""
    low = self.args.tensor_shape_range[0]
    high = self.args.tensor_shape_range[1]
    return self.randInt(low=low, high=high)
303
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for a test's expected output.

    Returns None for ERROR_IF tests (errorName set) or when the output or
    input dtype is not supported by compliance. Otherwise returns a dict with
    "mode" (a ComplianceMode name) and "data_type"; "dot_product_info" or
    "ulp_info" is added for those modes. Mode is chosen in this order:
    DOT_PRODUCT data gen -> DOT_PRODUCT, OP_SPECIAL data gen -> FP_SPECIAL,
    an op "compliance" entry with "ulp" -> ULP, REDUCE_PRODUCT op ->
    REDUCE_PRODUCT, anything else -> EXACT.
    """
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not (
        gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        and gtu.dtypeIsSupportedByCompliance(inputType)
    ):
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        ks_key = "ksb" if "ksb" in argsDict else "ks"
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ks_key]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
341
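# Build Op functions
# Each build_* function:
#  - creates the output tensor (calling OutputShaper as needed)
#  - does final tweaks to attributes (if necessary for ERROR_IF tests)
#  - adds the op into the graph
#  - returns the result tensor or a BuildInfo, or None when the ERROR_IF
#    validation (TosaErrorValidator.evValidateErrorIfs) rejects the test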
347
class BuildInfo:
    """Result of building an op: its output tensor plus compliance meta data."""

    def __init__(self, resultTensor, complianceDict):
        # resultTensor: the op's output tensor
        # complianceDict: compliance meta data dict, or None when there is none
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700354
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator on tensor a and return its output tensor.

    A bare integer op code, or an IDENTITY op, is added directly with no
    ERROR_IF processing. Otherwise the ERROR_IF validators are run and None
    is returned if they reject the test; NEGATE also gets a NegateAttribute
    built from qinfo.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int) or op["op"] == Op.IDENTITY:
        op_code = op if isinstance(op, int) else op["op"]
        self.ser.addOperator(op_code, a.name, out_tensor.name, None)
        return out_tensor

    # Ensure new output type has correct qinfo
    wrong_output = error_name == ErrorIf.WrongOutputType
    if wrong_output and out_tensor.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    if op["op"] == Op.NEGATE:
        # NOTE(review): qinfo is indexed here, so NEGATE callers presumably
        # always supply it - confirm.
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
405
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a broadcasting binary operator on the two tensors in inputs.

    Returns None if the ERROR_IF validators reject the test, otherwise a
    BuildInfo holding the output tensor and its compliance meta data
    (always None for POW, which has no compliance support yet).
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = (
        None  # TODO - add compliance support for POW
        if op["op"] == Op.POW
        else self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tensor, error_name
        )
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700451
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator on a and b; return its output.

    No ERROR_IF processing happens here: validator_fcns and error_name are
    accepted but not used.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    in_names = [a.name, b.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
456
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add a broadcasting binary op carrying an ArithmeticRightShiftAttribute.

    The `round` parameter (which shadows the builtin) is passed straight to
    the attribute; its name is kept for keyword-argument compatibility.
    Returns the output tensor, or None if the ERROR_IF validators reject the
    test.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], in_names, out_names, shift_attr)
    return out_tensor
494
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL-style broadcasting binary op with a MulAttribute(shift).

    Non floating point inputs force an INT32 result; for the WrongOutputType
    ERROR_IF test the output dtype is instead chosen at random from
    INT8/INT16/INT48. Returns None if the ERROR_IF validators reject the
    test, otherwise a BuildInfo with the output tensor and compliance data.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    shift = args_dict["shift"]

    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        out_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Output type for the WrongOutputType test (consumes self.rng)
        out_tensor.setDtype(
            self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], in_names, out_names, mul_attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE-style op on a with a TableAttribute built from table.

    Returns the output tensor, or None if the ERROR_IF validators reject the
    test.
    """
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)
    return out_tensor
584
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a three-input select op (cond, a, b) and return its output tensor.

    Returns None if the ERROR_IF validators reject the test.
    """
    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
621
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a binary comparison op on a and b and return its output tensor.

    Returns None if the ERROR_IF validators reject the test.
    """
    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
660
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Add an ARGMAX-style op over args_dict["axis"] of the single input.

    Returns None if the ERROR_IF validators reject the test, otherwise a
    BuildInfo with the output tensor and its compliance meta data.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700704
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test (build_maxpool2d delegates here).

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; reads "stride", "pad", "kernel" and, when
            present, "acc_type" (max_pool has no accum_dtype, so
            DType.UNKNOWN is used when the key is absent).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to PoolAttribute; [0, 0] is
            used when None.

    Returns:
        The result tensor itself (not a BuildInfo), or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholders + consts,
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return out
773
def build_maxpool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MAX_POOL2D operator test with compliance meta data.

    Operator construction is delegated to build_pool2d; this wrapper then
    attaches compliance information for the result tensor.

    Args:
        op: operator description dict.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments (see build_pool2d).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: zero points forwarded to build_pool2d.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        build_pool2d rejected the test during ERROR_IF validation.
    """
    result_tensor = self.build_pool2d(
        op,
        inputs,
        args_dict,
        validator_fcns,
        error_name,
        qinfo,
    )
    if result_tensor is None:
        # build_pool2d returns None when ERROR_IF validation rejects the
        # test. Propagate the same None rejection signal every other builder
        # returns, instead of requesting compliance data for a missing
        # tensor and wrapping None in a BuildInfo.
        return None

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100796
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: [ifm, weight, bias] tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" (must
            have 4 entries) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to ConvAttribute; treated as
            [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700875
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: [ifm, weight, bias] tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" (must
            have 6 entries) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to ConvAttribute; treated as
            [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
949
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        ifm: input tensor.
        filter: weight tensor (parameter name kept for caller compatibility).
        bias: bias tensor.
        accum_dtype: accumulator dtype passed to the output shaper.
        stride: stride values.
        out_pad: output padding, must have 4 entries.
        output_shape: requested output shape.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to TransposeConvAttribute;
            treated as [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1012
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: [ifm, weight, bias] tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" and
            "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to ConvAttribute; treated as
            [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1085
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: [ifm, weight, bias] tensors.
        args_dict: test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to FullyConnectedAttribute;
            treated as [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001141
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: [a, b] input tensors.
        args_dict: test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: pair of zero points forwarded to MatMulAttribute; treated as
            [0, 0] when None (the same fallback build_pool2d uses).

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero points of 0
    # (after validation, so validators still see the caller's value) rather
    # than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001191
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test.

    Args:
        op: operator description dict; its "op" and "operands" keys are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; "axis" is forwarded to the output shaper,
            the ERROR_IF validators and the AxisAttribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused; accepted so all build functions share one signature.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        data (always None for REDUCE_PRODUCT), or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    # TODO: Add compliance support for REDUCE_PRODUCT!
    compliance = (
        None
        if op["op"] == Op.REDUCE_PRODUCT
        else self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out, error_name
        )
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001240
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator test with random min/max bounds.

    Two random values of the input dtype become the bounds. For the
    ERROR_IF MaxSmallerMin case they are redrawn until distinct and then
    deliberately swapped so that max < min.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _draw_pair():
        # Two random values of the input dtype, drawn first then second.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = _draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = _draw_pair()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholders + consts,
    ):
        return None

    is_float = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if a.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    attr = ts.TosaSerializerAttribute()
    if is_float:
        # Float bounds occupy the last two ClampAttribute slots.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1296
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator test.

    The attribute value is a random FP32 number. No ERROR_IF validation is
    performed here, so validator_fcns is unused.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr_value = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [out.name], attr)
    return out
1305
# TODO: PRELU needs an additional type/input.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator test from a single input tensor.

    No ERROR_IF validation is performed here, so validator_fcns is unused.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1312
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    check_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1343
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks.
    inputs_named, outputs_named = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=inputs_named,
        output_list=outputs_named,
        num_operands=n_placeholders + n_consts,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs_named, outputs_named)
    return result
1374
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    erf_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    operand_count = op["operands"][0] + op["operands"][1]
    # Invalidate the input/output name lists for ERROR_IF checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_out.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=erf_out.shape,
        input_dtype=a.dtype,
        output_dtype=erf_out.dtype,
        result_tensors=[erf_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_count,
    ):
        self.ser.addOperator(op["op"], names_in, names_out)
        return erf_out
    return None
1405
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` (CONCAT) joining ``inputs`` along an axis.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: list of input tensors; ``inputs[0]`` supplies the shape and
            dtype handed to the ERROR_IF validators.
        args_dict: test arguments; ``args_dict["axis"]`` is the concat axis.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when the
        ERROR_IF validators reject the configuration.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(axis, int), f"axis must be an int, got {type(axis)}"

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001453
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize ``op["op"]`` (PAD) on the single tensor in ``inputs``.

    ``args_dict`` provides ``"pad"`` (flattened into the PadAttribute),
    ``"pad_const_int"`` and ``"pad_const_fp"``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and the metadata
        from ``self.tensorComplianceMetaData``, or None when the ERROR_IF
        validators reject the configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    padding = args_dict["pad"]
    pad_int = args_dict["pad_const_int"]
    pad_fp = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # The attribute is built up front (it uses the serializer's builder).
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, padding.flatten(), pad_int, pad_fp)

    # For ERROR_IF tests the name lists may be corrupted here.
    n_pl, n_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_pl + n_const,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001511
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize ``op["op"]`` (DIM) on the single tensor in ``inputs``.

    The axis is read from ``args_dict["axis"]`` and attached as an
    AxisAttribute. ``qinfo`` is not used.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when the
        ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    n_placeholder, n_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholder + n_const,
    )
    if not valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001557
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (RESHAPE) turning ``a`` into shape ``newShape``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    placeholder_count, const_count = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    check_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=reshaped.shape,
        input_dtype=a.dtype,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=ins,
        output_list=outs,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_kwargs
    ):
        return None

    shape_attr = ts.TosaSerializerAttribute()
    shape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], ins, outs, shape_attr)
    return reshaped
1593
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize ``op["op"]`` (REVERSE) on the single tensor in ``inputs``.

    The output shape comes from ``OutputShaper.unaryOp``; the axis is read
    from ``args_dict["axis"]`` and attached as an AxisAttribute. ``qinfo`` is
    not used.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when the
        ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    n_pl, n_c = op["operands"]
    operand_count = n_pl + n_c
    in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_list,
        output_list=out_list,
        num_operands=operand_count,
    ):
        reverse_attr = ts.TosaSerializerAttribute()
        reverse_attr.AxisAttribute(axis)
        self.ser.addOperator(op["op"], in_list, out_list, reverse_attr)
        return TosaTestGen.BuildInfo(result_tensor, None)

    return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001633
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (TRANSPOSE) permuting ``a`` by ``perms``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    transposed = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is created before the ERROR_IF checks, as before.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    num_pl, num_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [transposed.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=transposed.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=transposed.dtype,
        result_tensors=[transposed],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pl + num_const,
        input1=a,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return transposed
1669
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (SLICE) taking ``size`` elements of ``a`` from ``start``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    pl_count, const_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    validator_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=sliced.shape,
        input_dtype=a.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=inputs,
        output_list=outputs,
        num_operands=pl_count + const_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return sliced
1708
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (TILE) repeating ``a`` by ``multiples``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, constants = op["operands"]
    operand_total = placeholders + constants
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    ):
        tile_attr = ts.TosaSerializerAttribute()
        tile_attr.TileAttribute(multiples)
        self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
        return tiled

    return None
1743
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (GATHER) reading from ``values``.

    A constant INT32 indices tensor of shape (values.shape[0], W) is created
    here. W is drawn with ``self.randInt`` over ``self.args.tensor_shape_range``.
    Every index lies in [0, values.shape[1]).

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    batch = values.shape[0]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[batch, w_dim])
    )
    index_tensor = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(
        self.ser, self.rng, values, index_tensor, error_name
    )

    n_pl, n_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tensor.name], [gathered.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_pl + n_const,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1790
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (SCATTER) writing ``input`` into ``values_in``.

    A constant INT32 indices tensor of shape (N, W) is generated here, where
    N = values_in.shape[0], K = values_in.shape[1] and W = input.shape[1].
    Every index lies in [0, K).

    The TOSA specification does not allow a single SCATTER to write the same
    output index more than once within a batch. When W <= K, each row of
    indices is therefore W distinct values. When W > K no such row exists,
    so indices are drawn independently and duplicates are possible.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    K = values_in.shape[1]  # K
    W = input.shape[1]  # W
    N = values_in.shape[0]  # N
    if W <= K:
        # A shuffled prefix of range(K) per batch gives W unique indices.
        indicies_arr = np.asarray(
            [self.rng.permutation(K)[:W] for _ in range(N)], dtype=np.int32
        ).reshape(N, W)
    else:
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[N, W])
        )  # (N, W)
    indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indicies, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indicies.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1835
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize ``op["op"]`` (RESIZE) on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded to
    ``OutputShaper.resizeOp``, to the ERROR_IF validators and to the
    ResizeAttribute. ``input_dtype`` and ``output_dtype`` go to the shaper
    and the validators.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    checks = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1897
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` with two inputs (``val``, ``val2``) and two outputs.

    Each output shape comes from ``OutputShaper.unaryOp`` on its input. No
    ERROR_IF validation is performed, and ``validator_fcns`` is not used.

    Returns:
        The first result tensor; the second is serialized but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator description dict; the serializer takes the
    # operator itself from ``op["op"]``, as in the other builders.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1905
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor with the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001909
1910 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize ``op["op"]`` (CAST) converting ``val`` to ``out_dtype``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        configuration.
    """
    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    placeholders, consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [converted.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=converted.shape,
        input_dtype=val.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    ):
        self.ser.addOperator(op["op"], inputs, outputs)
        return converted

    return None
1943
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001944 def build_rescale(
1945 self,
1946 op,
1947 val,
1948 out_dtype,
1949 scale32,
1950 double_round,
1951 per_channel,
1952 validator_fcns,
1953 error_name,
1954 ):
1955 result_tens = OutputShaper.typeConversionOp(
1956 self.ser, self.rng, val, out_dtype, error_name
1957 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001958
1959 if per_channel:
1960 nc = val.shape[-1]
1961 else:
1962 nc = 1
1963
1964 in_type_width = self.typeWidth(val.dtype)
1965 out_type_width = self.typeWidth(out_dtype)
1966
Kevin Cheng3a478572021-01-22 17:21:02 -08001967 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001968 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001969 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001970 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001971 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001972 in_type_width += 1
1973 elif error_name in [
1974 ErrorIf.InputZeroPointNotZero,
1975 ErrorIf.U16InputZeroPointNotValid,
1976 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001977 input_zp = self.randInt(-128, 128)
1978 if input_zp == 0:
1979 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001980 in_type_width += 1
1981 elif val.dtype == DType.UINT16:
1982 # Must come after ErrorIf.U16InputZeroPointNotValid check
1983 input_zp = self.rng.choice([0, 32768])
1984 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001985 else:
1986 input_zp = 0
1987
Kevin Cheng3a478572021-01-22 17:21:02 -08001988 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001989 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001990 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001991 elif out_dtype == DType.UINT8:
1992 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001993 out_type_width += 1
1994 elif error_name in [
1995 ErrorIf.OutputZeroPointNotZero,
1996 ErrorIf.U16OutputZeroPointNotValid,
1997 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001998 output_zp = self.randInt(-128, 128)
1999 if output_zp == 0:
2000 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002001 out_type_width += 1
2002 elif out_dtype == DType.UINT16:
2003 # Must come after ErrorIf.U16OutputZeroPointNotValid check
2004 output_zp = self.rng.choice([0, 32768])
2005 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002006 else:
2007 output_zp = 0
2008
2009 # Calculate scale based on:
2010 # scale = a *(2^output_width)/(2^input_width))
2011
2012 a = np.float32(self.rng.random(size=[nc]))
2013 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2014
2015 if scale32:
2016 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002017 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002018 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2019 else:
2020 # Cap the scaling at 2^15 - 1 for scale16
2021 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2022
Kevin Cheng550ccc52021-03-03 11:21:43 -08002023 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002024
2025 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2026 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002027 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2028 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002029
2030 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002031 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2032 scale_arr[i], scale32
2033 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002034 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2035 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002036
Kevin Cheng550ccc52021-03-03 11:21:43 -08002037 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002038 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002039 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002040 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002041 assert val.placeholderFilename
2042 values = np.load(
2043 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2044 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002045 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2046 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2047 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2048 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002049 if not np.all(np.array_equal(values, val_adj)):
2050 # Values changed so overwrite file with new values
2051 np.save(
2052 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2053 val_adj,
2054 False,
2055 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002056
Matthew Haddonc2025212021-10-08 21:21:05 +01002057 # Invalidate Input/Output list for error if checks.
2058 input_list = [val.name]
2059 output_list = [result_tens.name]
2060 pCount, cCount = op["operands"]
2061 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002062 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2063 self, error_name, input_list, output_list
2064 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002065
2066 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002067 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002068 self.ser,
2069 validator_fcns,
2070 error_name,
2071 op=op,
2072 input_dtype=val.dtype,
2073 output_dtype=out_dtype,
2074 input_shape=val.shape,
2075 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002076 scale32=scale32,
2077 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002078 input_list=input_list,
2079 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002080 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002081 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002082 ):
2083 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002084
Eric Kunzee5e26762020-10-13 16:11:07 -07002085 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002086 attr.RescaleAttribute(
2087 input_zp,
2088 output_zp,
2089 multiplier_arr,
2090 shift_arr,
2091 scale32,
2092 double_round,
2093 per_channel,
2094 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002095
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002096 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002097 return result_tens
2098
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor used by COND_IF tests.

    Args:
        op: operator description dict, forwarded to the wrong-type helper.
        cond: the condition value stored in the constant.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The constant tensor created by ``self.ser.addConst``.
    """
    # The "condition not BOOL" ERROR_IF test needs some other dtype.
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # A valid condition must be of size 1 (rank 0).
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2115
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block outputs a random INT32 constant (values 0-255) of that shape. For
    the output-list mismatch ERROR_IF tests, the offending block instead
    outputs a constant of a perturbed shape.

    Args:
        op: operator description dict.
        then_tens: tensor supplying the output shape.
        else_tens: unused, apart from being part of the builder signature.
        cond: condition value for the constant condition tensor.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            # Dims above 3 may shrink or grow; smaller dims only grow, so the
            # perturbed shape keeps positive dimensions.
            delta = (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            bad_shape[idx] = dim + delta
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor can be based on either block's output.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds one const node that is also the block's output.
    for block, mismatch_error, block_arr in (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    ):
        self.ser.addBasicBlock(block)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not valid:
        return None

    return result_tens
2184
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a and b.

    Float and INT32 inputs use ADD (then) / SUB (else); INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (then) / LOGICAL_LEFT_SHIFT (else). For the input/output
    list mismatch ERROR_IF tests, the offending block gets a perturbed-shape
    input or output tensor.

    Args:
        op: operator description dict (must stay the dict; it is passed to
            the ERROR_IF validation at the end).
        a: first input tensor; also defines the result shape and dtype.
        b: second input tensor.
        cond: condition value for the constant condition tensor.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: reusing ``op`` here would overwrite the
    # operator description dict with an Op enum value, and the ERROR_IF
    # validation below (which indexes ``op``) would receive the wrong object.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2267
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that repeatedly adds ``a`` into an accumulator.

    The loop carries (counter, a, accumulator). The counter starts at
    ``iter_val`` and the accumulator starts at zero. The COND block computes
    ``counter > 0``; the BODY block computes ``acc + a`` and ``counter - 1``.
    For the WHILE_LOOP ERROR_IF tests, one of the tensor lists gets perturbed
    shapes, or the COND output gets the wrong dtype/shape.

    Args:
        op: operator description dict.
        a: tensor added into the accumulator on each iteration.
        iter_val: initial counter value.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation fails.
    """
    deltas = [-3, -2, 2, 3]

    def perturbed_copy(tensor):
        # Deep copy ``tensor`` and nudge each dimension by a random delta.
        clone = deepcopy(tensor)
        for idx in range(len(clone.shape)):
            clone.shape[idx] += self.rng.choice(deltas)
        return clone

    counter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator placeholder, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [counter.name, a.name, acc.name],
        [counter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        bad_counter = perturbed_copy(counter)
        if len(bad_counter.shape) == 0:
            # The counter is rank 0, so give it a dimension to mismatch on.
            bad_counter.shape.append(self.rng.choice(deltas))
        bad_acc = perturbed_copy(acc)

    # COND block (inputs: counter, a, acc; output: cond_tens = counter > 0)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_counter, a, bad_acc)
    else:
        cond_inputs = (counter, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [counter.name, zero_tens.name], [cond_tens.name])

    # BODY block (inputs: counter, a, acc; outputs: counter - 1, a, acc + a).
    # Local intermediate tensors are declared here for the block outputs.
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_counter, a, bad_acc)
    else:
        body_inputs = (counter, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        counter_like, acc_like = bad_counter, bad_acc
    else:
        counter_like, acc_like = counter, acc
    counter_body_out = self.ser.addIntermediate(counter_like.shape, counter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [counter.name, one_tens.name], [counter_body_out.name])
    for tens in (counter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not valid:
        return None

    return acc_out
2388
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D test from the two input tensors ``val1`` and ``val2``.

    The output tensors come from ``OutputShaper.fft2dOp``. (val1/val2 are
    presumably the real/imaginary parts — confirm against fft2dOp.)

    Args:
        op: operator description dict.
        val1: first input tensor; also supplies the input shape/dtype checked
            by the ERROR_IF validators.
        val2: second input tensor.
        inverse: value stored in the FFT attribute.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The list of output tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    in_names = [val1.name, val2.name]
    placeholder_count, const_count = op["operands"]

    out_names, out_shapes, out_dtypes = [], [], []
    for res in results:
        out_names.append(res.name)
        out_shapes.append(res.shape)
        out_dtypes.append(res.dtype)

    # For some ERROR_IF tests the operand lists are deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2430
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D test from the single input tensor ``val``.

    The output tensors come from ``OutputShaper.rfft2dOp``. No attribute is
    attached to the operator.

    Args:
        op: operator description dict.
        val: input tensor.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The list of output tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    in_names = [val.name]
    placeholder_count, const_count = op["operands"]

    out_names, out_shapes, out_dtypes = [], [], []
    for res in results:
        out_names.append(res.name)
        out_shapes.append(res.shape)
        out_dtypes.append(res.dtype)

    # For some ERROR_IF tests the operand lists are deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2464
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator description dict; its "rank" (min, max) pair and
            "types" list are read here.
        shapeFilter: requested shapes; a falsy value means "no shape filter"
            and is replaced by ``[None]``.
        rankFilter: requested ranks, or None for the default selection.
        dtypeFilter: requested dtypes, or None to allow all of the op's types.
            An op type that is a list matches when its first element is
            requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests. It is called with
            ``check=False`` and its "param_reqs" rank/dtype/shape entries
            override the computed filters when not None.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries, or
        None for a negative test without a validator or an unknown testType.
    """
    # Default rank sweep is 1-4 inclusive, to keep test sizes reasonably small.
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    low, high = op["rank"]
    if rankFilter is not None:
        # Keep only requested ranks that the operator allows.
        ranks = [rank for rank in rankFilter if low <= rank <= high]
    elif shapes[0] is None:
        # Nothing requested: use the operator's rank range or the default
        # range, whichever is smaller.
        op_ranks = range(low, high + 1)
        ranks = default_ranks if len(op_ranks) > len(default_ranks) else op_ranks
    else:
        ranks = range(low, high + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}
    if testType != "negative" or validator is None:
        return None

    requirements = validator(check=False, op=op)["param_reqs"]
    required_rank = requirements["rank"]
    required_dtype = requirements["dtype"]
    required_shape = requirements["shape"]

    return {
        # Without a required shape, only keep two shapes to limit test count.
        "shapeFilter": shapes[:2] if required_shape is None else required_shape,
        "rankFilter": ranks if required_rank is None else required_rank,
        "dtypeFilter": dtypes if required_dtype is None else required_dtype,
    }
2545
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptors for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to restrict generation to; None (the default)
            means no shape filter and is treated as ``[None]``.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: "positive" or "negative" (ERROR_IF) tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. Positive tests rejected by any of the op's
        "invalid_test_validators" are dropped.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Avoid a shared mutable default; [None] means "no shape filter".
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists are (str, [build function args]) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    if testType == "negative":
                        baseStr = f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                    else:
                        baseStr = f"{opName}_{shapeStr}_{typeStr}"

                    for argStr, args in argList:
                        testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2652
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002653 def serializeTest(
2654 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2655 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002656 try:
2657 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002658 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002659 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002660
Jeremy Johnson0c716862023-04-13 17:18:19 +01002661 if self.args.verbose:
2662 print(f"Creating {testStr}")
2663
Eric Kunzee5e26762020-10-13 16:11:07 -07002664 # Create a serializer
2665 self.createSerializer(opName, testStr)
2666
Jeremy Johnson1271c442023-09-05 11:39:26 +01002667 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002668 if "error_if_validators" in op:
2669 error_if_validators = op["error_if_validators"]
2670 else:
2671 error_if_validators = None
2672
Kevin Cheng550ccc52021-03-03 11:21:43 -08002673 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002674 num_operands = pCount + cCount
2675
2676 if isinstance(dtype_or_dtypeList, list):
2677 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002678 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002679 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002680 else:
2681 dtypeList = [dtype_or_dtypeList] * (num_operands)
2682
Kevin Cheng93a16282021-08-31 16:14:03 -07002683 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002684 assert (
2685 len(shapeList) == num_operands
2686 ), "shapeList length {} must match number of operands {}".format(
2687 len(shapeList), num_operands
2688 )
2689 assert (
2690 len(dtypeList) == num_operands
2691 ), "dtypeList length {} must match number of operands {}".format(
2692 len(dtypeList), num_operands
2693 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002694
2695 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002696 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002697 except KeyError:
2698 qgen = None
2699
2700 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002701
Matthew Haddon1c00b712021-10-01 15:51:03 +01002702 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002703 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002704 else:
2705 qinfo = None
2706
Jeremy Johnson1271c442023-09-05 11:39:26 +01002707 # Extra meta data for the desc.json
2708 tensMeta = {}
2709
2710 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002711 if isinstance(testArgs, dict):
2712 # New interface with args info in dictionary
2713 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002714 assert "dg_type" in argsDict
2715 tvgInfo = tvgen_fcn(
2716 self, opName, dtypeList, shapeList, argsDict, error_name
2717 )
2718 if tvgInfo.dataGenDict:
2719 tensMeta["data_gen"] = tvgInfo.dataGenDict
2720 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002721
2722 result = build_fcn(
2723 self,
2724 op,
2725 tens,
2726 argsDict,
2727 validator_fcns=error_if_validators,
2728 error_name=error_name,
2729 qinfo=qinfo,
2730 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002731 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002732 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002733 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002734
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002735 try:
2736 if error_if_validators is None:
2737 if qinfo is not None:
2738 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2739 else:
2740 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002741 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002742 if qinfo is not None:
2743 result = build_fcn(
2744 self,
2745 op,
2746 *tens,
2747 *testArgs,
2748 validator_fcns=error_if_validators,
2749 error_name=error_name,
2750 qinfo=qinfo,
2751 )
2752 else:
2753 result = build_fcn(
2754 self,
2755 op,
2756 *tens,
2757 *testArgs,
2758 validator_fcns=error_if_validators,
2759 error_name=error_name,
2760 )
2761 except TypeError as e:
2762 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2763 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002764
Jeremy Johnson1271c442023-09-05 11:39:26 +01002765 if result:
Les Bell729b0352021-11-24 10:28:21 +00002766 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002767 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2768 # Add the compliance meta data
2769 # NOTE: This currently expects only one result output
2770 tensMeta["compliance"] = {
2771 "version": "0.1",
2772 "tensors": {result.resultTensor.name: result.complianceDict},
2773 }
2774 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002775 else:
2776 # The test is not valid
2777 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002778
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    Each ``<prefix>_TEMPLATE`` entry is shallow-copied once per kernel size.
    The copy is stored as ``<prefix>_<kernel dims joined by 'x'>`` with its
    ``filter`` set to that kernel and ``template`` set to False. Afterwards
    every entry whose ``template`` field is truthy is removed. If
    ``conv2d_TEMPLATE`` is already gone, the lists were expanded earlier and
    nothing is done.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Pick the kernel sizes; level 8K testing uses the largest allowed kernel
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def clone_template(prefix, size_tag, kernel):
        # Copy the template, specialise it to one kernel, and register it
        entry = op_list[f"{prefix}_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[f"{prefix}_{size_tag}"] = entry

    for kernel in kernels_2d:
        size_tag = f"{kernel[0]}x{kernel[1]}"
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            clone_template(prefix, size_tag, kernel)

    for kernel in kernels_3d:
        clone_template("conv3d", f"{kernel[0]}x{kernel[1]}x{kernel[2]}", kernel)

    # Gather the template names first: deleting keys while iterating the
    # dict itself is not allowed
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2834
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must provide:
      * "operands": a 2-item (placeholder count, const count) sequence
      * "build_fcn": a 4-item (build function, TensorGen, TensorValuesGen,
        ArgGen) sequence
      * "types": the datatypes to test
      * "op": the TOSA Op

    An entry without "rank" gets DEFAULT_RANK_RANGE written into it in place.

    Raises:
        Exception: for the first op with a missing or malformed required
            field. For "operands" and "build_fcn" the underlying
            KeyError/ValueError/TypeError is chained as ``__cause__``.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields; unpacking also validates the tuple length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _fcn, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002876
2877 # Tensor operator list
2878 # 'op': op name
2879 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002880 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2881 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002882 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2883 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002884 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002885
Kevin Cheng550ccc52021-03-03 11:21:43 -08002886 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002887 TYPE_INT_FP = [
2888 DType.INT8,
2889 DType.INT16,
2890 DType.INT32,
2891 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002892 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002893 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002894 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002895
Kevin Cheng550ccc52021-03-03 11:21:43 -08002896 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002897 TYPE_FI32 = [
2898 DType.FP32,
2899 DType.FP16,
2900 DType.BF16,
2901 DType.INT32,
2902 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002903 TYPE_FIB = [
2904 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002905 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002906 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002907 DType.INT8,
2908 DType.INT16,
2909 DType.INT32,
2910 DType.BOOL,
2911 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002912 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002913
James Ward24dbc422022-10-19 12:20:31 +01002914 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002915
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002916 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002917 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002918 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002919 [DType.INT8, DType.INT8, DType.INT32],
2920 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002921 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002922 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002923 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002924 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002925 ]
2926
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002927 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002928
2929 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002930 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002931 "argmax": {
2932 "op": Op.ARGMAX,
2933 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002934 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002935 "build_fcn": (
2936 build_argmax,
2937 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002938 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 TosaArgGen.agAxis,
2940 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002941 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002942 "error_if_validators": (
2943 TosaErrorValidator.evAxisSmallerZero,
2944 TosaErrorValidator.evAxisLargerRank,
2945 TosaErrorValidator.evArgmaxOutputRankMismatch,
2946 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2947 TosaErrorValidator.evWrongRank,
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002953 "data_gen": {
2954 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2955 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002956 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002957 "avg_pool2d": {
2958 "op": Op.AVG_POOL2D,
2959 "operands": (1, 0),
2960 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002961 "build_fcn": (
2962 build_pool2d,
2963 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002964 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002965 TosaArgGen.agPooling,
2966 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002967 "qgen": TosaQuantGen.qgUnary,
2968 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002969 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002970 "error_if_validators": (
2971 TosaErrorValidator.evKernelSmallerOne,
2972 TosaErrorValidator.evStrideSmallerOne,
2973 TosaErrorValidator.evPadSmallerZero,
2974 TosaErrorValidator.evWrongRank,
2975 TosaErrorValidator.evWrongInputType,
2976 TosaErrorValidator.evWrongOutputType,
2977 TosaErrorValidator.evWrongInputList,
2978 TosaErrorValidator.evWrongOutputList,
2979 TosaErrorValidator.evInputZeroPointNotZero,
2980 TosaErrorValidator.evOutputZeroPointNotZero,
2981 TosaErrorValidator.evPadLargerEqualKernel,
2982 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002983 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002984 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002985 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002986 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002987 "conv2d_TEMPLATE": {
2988 "op": Op.CONV2D,
2989 "operands": (1, 2),
2990 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002991 "build_fcn": (
2992 build_conv2d,
2993 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002994 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002995 TosaArgGen.agConv,
2996 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002997 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002998 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002999 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3000 "error_if_validators": (
3001 TosaErrorValidator.evWrongInputType,
3002 TosaErrorValidator.evWrongOutputType,
3003 TosaErrorValidator.evWrongInputList,
3004 TosaErrorValidator.evWrongOutputList,
3005 TosaErrorValidator.evInputZeroPointNotZero,
3006 TosaErrorValidator.evWeightZeroPointNotZero,
3007 TosaErrorValidator.evPadSmallerZero,
3008 TosaErrorValidator.evStrideSmallerOne,
3009 TosaErrorValidator.evDilationSmallerOne,
3010 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003011 TosaErrorValidator.evConvOutputShapeMismatch,
3012 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003013 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003014 "data_gen": {
3015 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3016 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003017 "template": True,
3018 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003019 # Templated operator. Filled in by createDynamicOpLists
3020 "conv3d_TEMPLATE": {
3021 "op": Op.CONV3D,
3022 "operands": (1, 2),
3023 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003024 "build_fcn": (
3025 build_conv3d,
3026 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003027 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003028 TosaArgGen.agConv,
3029 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003030 "qgen": TosaQuantGen.qgConv,
3031 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003032 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3033 "error_if_validators": (
3034 TosaErrorValidator.evWrongInputType,
3035 TosaErrorValidator.evWrongOutputType,
3036 TosaErrorValidator.evWrongInputList,
3037 TosaErrorValidator.evWrongOutputList,
3038 TosaErrorValidator.evInputZeroPointNotZero,
3039 TosaErrorValidator.evWeightZeroPointNotZero,
3040 TosaErrorValidator.evPadSmallerZero,
3041 TosaErrorValidator.evStrideSmallerOne,
3042 TosaErrorValidator.evDilationSmallerOne,
3043 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003044 TosaErrorValidator.evConvOutputShapeMismatch,
3045 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003046 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003047 "template": True,
3048 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003049 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 "depthwise_conv2d_TEMPLATE": {
3051 "op": Op.DEPTHWISE_CONV2D,
3052 "operands": (1, 2),
3053 "filter": [1, 1],
3054 "rank": (4, 4),
3055 "build_fcn": (
3056 build_depthwise_conv2d,
3057 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003058 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003059 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003060 ),
3061 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003062 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003063 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3064 "error_if_validators": (
3065 TosaErrorValidator.evWrongInputType,
3066 TosaErrorValidator.evWrongOutputType,
3067 TosaErrorValidator.evWrongInputList,
3068 TosaErrorValidator.evWrongOutputList,
3069 TosaErrorValidator.evInputZeroPointNotZero,
3070 TosaErrorValidator.evWeightZeroPointNotZero,
3071 TosaErrorValidator.evPadSmallerZero,
3072 TosaErrorValidator.evStrideSmallerOne,
3073 TosaErrorValidator.evDilationSmallerOne,
3074 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003075 TosaErrorValidator.evConvOutputShapeMismatch,
3076 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003077 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003078 "template": True,
3079 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "fully_connected": {
3081 "op": Op.FULLY_CONNECTED,
3082 "operands": (1, 2),
3083 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003084 "build_fcn": (
3085 build_fully_connected,
3086 TosaTensorGen.tgFullyConnected,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003087 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003088 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003089 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003091 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003092 "error_if_validators": (
3093 TosaErrorValidator.evInputZeroPointNotZero,
3094 TosaErrorValidator.evWeightZeroPointNotZero,
3095 TosaErrorValidator.evWrongRank,
3096 TosaErrorValidator.evWrongInputType,
3097 TosaErrorValidator.evWrongOutputType,
3098 TosaErrorValidator.evWrongInputList,
3099 TosaErrorValidator.evWrongOutputList,
3100 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003101 "data_gen": {
3102 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3103 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "matmul": {
3106 "op": Op.MATMUL,
3107 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003108 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003109 "build_fcn": (
3110 build_matmul,
3111 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003112 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003113 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003114 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003115 "qgen": TosaQuantGen.qgMatmul,
3116 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003117 "error_if_validators": (
3118 TosaErrorValidator.evInputZeroPointNotZero,
3119 TosaErrorValidator.evWrongRank,
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003125 "data_gen": {
3126 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003128 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 "max_pool2d": {
3130 "op": Op.MAX_POOL2D,
3131 "operands": (1, 0),
3132 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003133 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003134 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003135 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003136 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 TosaArgGen.agPooling,
3138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003140 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 "error_if_validators": (
3142 TosaErrorValidator.evKernelSmallerOne,
3143 TosaErrorValidator.evStrideSmallerOne,
3144 TosaErrorValidator.evPadSmallerZero,
3145 TosaErrorValidator.evWrongRank,
3146 TosaErrorValidator.evWrongInputType,
3147 TosaErrorValidator.evWrongOutputType,
3148 TosaErrorValidator.evWrongInputList,
3149 TosaErrorValidator.evWrongOutputList,
3150 TosaErrorValidator.evPadLargerEqualKernel,
3151 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003152 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003154 "data_gen": {
3155 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003158 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003159 "transpose_conv2d_TEMPLATE": {
3160 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003161 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003162 "rank": (4, 4),
3163 "build_fcn": (
3164 build_transpose_conv2d,
3165 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003167 TosaArgGen.agTransposeConv2D,
3168 ),
3169 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003170 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003171 "invalid_test_validators": (
3172 TosaInvalidValidator.ivHeightWidthInvalid,
3173 TosaInvalidValidator.ivNonPositiveOutputShape,
3174 ),
3175 "error_if_validators": (
3176 TosaErrorValidator.evWrongInputType,
3177 TosaErrorValidator.evWrongOutputType,
3178 TosaErrorValidator.evWrongInputList,
3179 TosaErrorValidator.evWrongOutputList,
3180 TosaErrorValidator.evInputZeroPointNotZero,
3181 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003182 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003183 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003184 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003185 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003186 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003187 "template": True,
3188 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003189 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003190 "clamp": {
3191 "op": Op.CLAMP,
3192 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003193 "build_fcn": (
3194 build_clamp,
3195 TosaTensorGen.tgBasic,
3196 TosaTensorValuesGen.tvgDefault,
3197 None,
3198 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003199 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003200 "error_if_validators": (
3201 TosaErrorValidator.evMaxSmallerMin,
3202 TosaErrorValidator.evWrongInputType,
3203 TosaErrorValidator.evWrongOutputType,
3204 TosaErrorValidator.evWrongInputList,
3205 TosaErrorValidator.evWrongOutputList,
3206 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003207 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003208 "sigmoid": {
3209 "op": Op.SIGMOID,
3210 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003211 "build_fcn": (
3212 build_sigmoid,
3213 TosaTensorGen.tgBasic,
3214 TosaTensorValuesGen.tvgDefault,
3215 None,
3216 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003217 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003218 "error_if_validators": (
3219 TosaErrorValidator.evWrongInputType,
3220 TosaErrorValidator.evWrongOutputType,
3221 TosaErrorValidator.evWrongInputList,
3222 TosaErrorValidator.evWrongOutputList,
3223 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003224 },
3225 "tanh": {
3226 "op": Op.TANH,
3227 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_tanh,
3230 TosaTensorGen.tgBasic,
3231 TosaTensorValuesGen.tvgDefault,
3232 None,
3233 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003234 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 "error_if_validators": (
3236 TosaErrorValidator.evWrongInputType,
3237 TosaErrorValidator.evWrongOutputType,
3238 TosaErrorValidator.evWrongInputList,
3239 TosaErrorValidator.evWrongOutputList,
3240 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003241 },
Won Jeon78155c62023-06-10 00:20:04 +00003242 "erf": {
3243 "op": Op.ERF,
3244 "operands": (1, 0),
3245 "build_fcn": (
3246 build_erf,
3247 TosaTensorGen.tgBasic,
3248 TosaTensorValuesGen.tvgDefault,
3249 None,
3250 ),
3251 "types": TYPE_FP,
3252 "error_if_validators": (
3253 TosaErrorValidator.evWrongInputType,
3254 TosaErrorValidator.evWrongOutputType,
3255 TosaErrorValidator.evWrongInputList,
3256 TosaErrorValidator.evWrongOutputList,
3257 ),
3258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 # Elementwise Binary Operators
3260 "add": {
3261 "op": Op.ADD,
3262 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003263 "build_fcn": (
3264 build_binary_broadcast,
3265 TosaTensorGen.tgBroadcastFuzz,
3266 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003267 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003270 "error_if_validators": (
3271 TosaErrorValidator.evRankMismatch,
3272 TosaErrorValidator.evWrongInputType,
3273 TosaErrorValidator.evWrongOutputType,
3274 TosaErrorValidator.evWrongInputList,
3275 TosaErrorValidator.evWrongOutputList,
3276 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003277 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003278 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003279 "data_gen": {
3280 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3281 },
3282 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 "arithmetic_right_shift": {
3285 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3286 "operands": (2, 0),
3287 "build_fcn": (
3288 build_arithmetic_right_shift,
3289 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003290 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003291 TosaArgGen.agArithmeticRightShift,
3292 ),
3293 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003294 "error_if_validators": (
3295 TosaErrorValidator.evRankMismatch,
3296 TosaErrorValidator.evWrongInputType,
3297 TosaErrorValidator.evWrongOutputType,
3298 TosaErrorValidator.evWrongInputList,
3299 TosaErrorValidator.evWrongOutputList,
3300 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003301 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003302 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "bitwise_and": {
3305 "op": Op.BITWISE_AND,
3306 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 "build_fcn": (
3308 build_binary_broadcast,
3309 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003310 TosaTensorValuesGen.tvgLazyGenDefault,
3311 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003314 "error_if_validators": (
3315 TosaErrorValidator.evRankMismatch,
3316 TosaErrorValidator.evWrongInputType,
3317 TosaErrorValidator.evWrongOutputType,
3318 TosaErrorValidator.evWrongInputList,
3319 TosaErrorValidator.evWrongOutputList,
3320 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003321 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 "bitwise_or": {
3325 "op": Op.BITWISE_OR,
3326 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003327 "build_fcn": (
3328 build_binary_broadcast,
3329 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003330 TosaTensorValuesGen.tvgLazyGenDefault,
3331 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003334 "error_if_validators": (
3335 TosaErrorValidator.evRankMismatch,
3336 TosaErrorValidator.evWrongInputType,
3337 TosaErrorValidator.evWrongOutputType,
3338 TosaErrorValidator.evWrongInputList,
3339 TosaErrorValidator.evWrongOutputList,
3340 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003341 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003344 "bitwise_xor": {
3345 "op": Op.BITWISE_XOR,
3346 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 "build_fcn": (
3348 build_binary_broadcast,
3349 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003350 TosaTensorValuesGen.tvgLazyGenDefault,
3351 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003354 "error_if_validators": (
3355 TosaErrorValidator.evRankMismatch,
3356 TosaErrorValidator.evWrongInputType,
3357 TosaErrorValidator.evWrongOutputType,
3358 TosaErrorValidator.evWrongInputList,
3359 TosaErrorValidator.evWrongOutputList,
3360 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003361 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003364 "intdiv": {
3365 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003366 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 "build_fcn": (
3368 build_binary_broadcast,
3369 TosaTensorGen.tgBroadcastFuzz,
3370 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003371 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003373 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003374 "error_if_validators": (
3375 TosaErrorValidator.evRankMismatch,
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003381 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003382 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 "logical_and": {
3385 "op": Op.LOGICAL_AND,
3386 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 "build_fcn": (
3388 build_binary_broadcast,
3389 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003390 TosaTensorValuesGen.tvgLazyGenDefault,
3391 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003394 "error_if_validators": (
3395 TosaErrorValidator.evRankMismatch,
3396 TosaErrorValidator.evWrongInputType,
3397 TosaErrorValidator.evWrongOutputType,
3398 TosaErrorValidator.evWrongInputList,
3399 TosaErrorValidator.evWrongOutputList,
3400 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003401 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 "logical_left_shift": {
3405 "op": Op.LOGICAL_LEFT_SHIFT,
3406 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 "build_fcn": (
3408 build_binary_broadcast,
3409 TosaTensorGen.tgBroadcastFuzz,
3410 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003411 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003414 "error_if_validators": (
3415 TosaErrorValidator.evRankMismatch,
3416 TosaErrorValidator.evWrongInputType,
3417 TosaErrorValidator.evWrongOutputType,
3418 TosaErrorValidator.evWrongInputList,
3419 TosaErrorValidator.evWrongOutputList,
3420 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003421 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 "logical_right_shift": {
3425 "op": Op.LOGICAL_RIGHT_SHIFT,
3426 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003427 "build_fcn": (
3428 build_binary_broadcast,
3429 TosaTensorGen.tgBroadcastFuzz,
3430 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003431 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003434 "error_if_validators": (
3435 TosaErrorValidator.evRankMismatch,
3436 TosaErrorValidator.evWrongInputType,
3437 TosaErrorValidator.evWrongOutputType,
3438 TosaErrorValidator.evWrongInputList,
3439 TosaErrorValidator.evWrongOutputList,
3440 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003441 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "logical_or": {
3445 "op": Op.LOGICAL_OR,
3446 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003447 "build_fcn": (
3448 build_binary_broadcast,
3449 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003450 TosaTensorValuesGen.tvgLazyGenDefault,
3451 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003454 "error_if_validators": (
3455 TosaErrorValidator.evRankMismatch,
3456 TosaErrorValidator.evWrongInputType,
3457 TosaErrorValidator.evWrongOutputType,
3458 TosaErrorValidator.evWrongInputList,
3459 TosaErrorValidator.evWrongOutputList,
3460 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003461 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003462 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "logical_xor": {
3465 "op": Op.LOGICAL_XOR,
3466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_binary_broadcast,
3469 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003470 TosaTensorValuesGen.tvgLazyGenDefault,
3471 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evRankMismatch,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003481 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "maximum": {
3485 "op": Op.MAXIMUM,
3486 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 "build_fcn": (
3488 build_binary_broadcast,
3489 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003490 TosaTensorValuesGen.tvgLazyGenDefault,
3491 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003494 "error_if_validators": (
3495 TosaErrorValidator.evRankMismatch,
3496 TosaErrorValidator.evWrongInputType,
3497 TosaErrorValidator.evWrongOutputType,
3498 TosaErrorValidator.evWrongInputList,
3499 TosaErrorValidator.evWrongOutputList,
3500 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003501 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003503 "data_gen": {
3504 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3505 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003507 "minimum": {
3508 "op": Op.MINIMUM,
3509 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003510 "build_fcn": (
3511 build_binary_broadcast,
3512 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003513 TosaTensorValuesGen.tvgLazyGenDefault,
3514 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003515 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003516 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003517 "error_if_validators": (
3518 TosaErrorValidator.evRankMismatch,
3519 TosaErrorValidator.evWrongInputType,
3520 TosaErrorValidator.evWrongOutputType,
3521 TosaErrorValidator.evWrongInputList,
3522 TosaErrorValidator.evWrongOutputList,
3523 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003524 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003525 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003526 "data_gen": {
3527 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003530 "mul": {
3531 "op": Op.MUL,
3532 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003533 "build_fcn": (
3534 build_mul,
3535 TosaTensorGen.tgBroadcastFuzz,
3536 TosaTensorValuesGen.tvgMul,
3537 TosaArgGen.agMul,
3538 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003539 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003540 "error_if_validators": (
3541 TosaErrorValidator.evWrongInputType,
3542 TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 TosaErrorValidator.evRankMismatch,
3546 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003547 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003548 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003549 "data_gen": {
3550 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3551 },
3552 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003554 "pow": {
3555 "op": Op.POW,
3556 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 "build_fcn": (
3558 build_binary_broadcast,
3559 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003560 TosaTensorValuesGen.tvgLazyGenDefault,
3561 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003564 "error_if_validators": (
3565 TosaErrorValidator.evRankMismatch,
3566 TosaErrorValidator.evWrongInputType,
3567 TosaErrorValidator.evWrongOutputType,
3568 TosaErrorValidator.evWrongInputList,
3569 TosaErrorValidator.evWrongOutputList,
3570 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003571 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003572 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 "sub": {
3575 "op": Op.SUB,
3576 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003577 "build_fcn": (
3578 build_binary_broadcast,
3579 TosaTensorGen.tgBroadcastFuzz,
3580 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003581 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003582 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003584 "error_if_validators": (
3585 TosaErrorValidator.evRankMismatch,
3586 TosaErrorValidator.evWrongInputType,
3587 TosaErrorValidator.evWrongOutputType,
3588 TosaErrorValidator.evWrongInputList,
3589 TosaErrorValidator.evWrongOutputList,
3590 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003591 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003593 "data_gen": {
3594 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3595 },
3596 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 "table": {
3599 "op": Op.TABLE,
3600 # Use the automatic generation functions to create the input array
3601 # but create the table tensor in the build function, as it may be
3602 # a different type from the input
3603 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003604 "build_fcn": (
3605 build_table,
3606 TosaTensorGen.tgBasic,
3607 TosaTensorValuesGen.tvgDefault,
3608 TosaArgGen.agTable,
3609 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003610 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003611 "error_if_validators": (
3612 TosaErrorValidator.evWrongInputType,
3613 TosaErrorValidator.evWrongOutputType,
3614 TosaErrorValidator.evWrongInputList,
3615 TosaErrorValidator.evWrongOutputList,
3616 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 # Elementwise Unary operators
3619 "abs": {
3620 "op": Op.ABS,
3621 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 "build_fcn": (
3623 build_unary,
3624 TosaTensorGen.tgBasic,
3625 TosaTensorValuesGen.tvgDefault,
3626 None,
3627 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 "error_if_validators": (
3630 TosaErrorValidator.evWrongInputType,
3631 TosaErrorValidator.evWrongOutputType,
3632 TosaErrorValidator.evWrongInputList,
3633 TosaErrorValidator.evWrongOutputList,
3634 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "bitwise_not": {
3637 "op": Op.BITWISE_NOT,
3638 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_unary,
3641 TosaTensorGen.tgBasic,
3642 TosaTensorValuesGen.tvgDefault,
3643 None,
3644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evWrongInputType,
3648 TosaErrorValidator.evWrongOutputType,
3649 TosaErrorValidator.evWrongInputList,
3650 TosaErrorValidator.evWrongOutputList,
3651 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "ceil": {
3654 "op": Op.CEIL,
3655 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003656 "build_fcn": (
3657 build_unary,
3658 TosaTensorGen.tgBasic,
3659 TosaTensorValuesGen.tvgDefault,
3660 None,
3661 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003662 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003663 "error_if_validators": (
3664 TosaErrorValidator.evWrongInputType,
3665 TosaErrorValidator.evWrongOutputType,
3666 TosaErrorValidator.evWrongInputList,
3667 TosaErrorValidator.evWrongOutputList,
3668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 "clz": {
3671 "op": Op.CLZ,
3672 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
3674 build_unary,
3675 TosaTensorGen.tgBasic,
3676 TosaTensorValuesGen.tvgDefault,
3677 None,
3678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "exp": {
3688 "op": Op.EXP,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_unary,
3692 TosaTensorGen.tgBasic,
3693 TosaTensorValuesGen.tvgDefault,
3694 None,
3695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evWrongInputType,
3699 TosaErrorValidator.evWrongOutputType,
3700 TosaErrorValidator.evWrongInputList,
3701 TosaErrorValidator.evWrongOutputList,
3702 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003704 "floor": {
3705 "op": Op.FLOOR,
3706 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 "build_fcn": (
3708 build_unary,
3709 TosaTensorGen.tgBasic,
3710 TosaTensorValuesGen.tvgDefault,
3711 None,
3712 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 "error_if_validators": (
3715 TosaErrorValidator.evWrongInputType,
3716 TosaErrorValidator.evWrongOutputType,
3717 TosaErrorValidator.evWrongInputList,
3718 TosaErrorValidator.evWrongOutputList,
3719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "log": {
3722 "op": Op.LOG,
3723 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003724 "build_fcn": (
3725 build_unary,
3726 TosaTensorGen.tgBasic,
3727 TosaTensorValuesGen.tvgDefault,
3728 None,
3729 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 "error_if_validators": (
3732 TosaErrorValidator.evWrongInputType,
3733 TosaErrorValidator.evWrongOutputType,
3734 TosaErrorValidator.evWrongInputList,
3735 TosaErrorValidator.evWrongOutputList,
3736 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "logical_not": {
3739 "op": Op.LOGICAL_NOT,
3740 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 "build_fcn": (
3742 build_unary,
3743 TosaTensorGen.tgBasic,
3744 TosaTensorValuesGen.tvgDefault,
3745 None,
3746 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 "error_if_validators": (
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongInputList,
3752 TosaErrorValidator.evWrongOutputList,
3753 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 "negate": {
3756 "op": Op.NEGATE,
3757 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 "build_fcn": (
3759 build_unary,
3760 TosaTensorGen.tgBasic,
3761 TosaTensorValuesGen.tvgNegate,
3762 None,
3763 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 "qgen": TosaQuantGen.qgUnary,
3765 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003766 "error_if_validators": (
3767 TosaErrorValidator.evInputZeroPointNotZero,
3768 TosaErrorValidator.evOutputZeroPointNotZero,
3769 TosaErrorValidator.evWrongInputType,
3770 TosaErrorValidator.evWrongOutputType,
3771 TosaErrorValidator.evWrongInputList,
3772 TosaErrorValidator.evWrongOutputList,
3773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "reciprocal": {
3776 "op": Op.RECIPROCAL,
3777 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 "build_fcn": (
3779 build_unary,
3780 TosaTensorGen.tgBasic,
3781 TosaTensorValuesGen.tvgDefault,
3782 None,
3783 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003784 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 "error_if_validators": (
3786 TosaErrorValidator.evWrongInputType,
3787 TosaErrorValidator.evWrongOutputType,
3788 TosaErrorValidator.evWrongInputList,
3789 TosaErrorValidator.evWrongOutputList,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 "rsqrt": {
3793 "op": Op.RSQRT,
3794 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 "build_fcn": (
3796 build_unary,
3797 TosaTensorGen.tgBasic,
3798 TosaTensorValuesGen.tvgDefault,
3799 None,
3800 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003802 "error_if_validators": (
3803 TosaErrorValidator.evWrongInputType,
3804 TosaErrorValidator.evWrongOutputType,
3805 TosaErrorValidator.evWrongInputList,
3806 TosaErrorValidator.evWrongOutputList,
3807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 # Elementwise Ternary operators
3810 "select": {
3811 "op": Op.SELECT,
3812 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003813 "build_fcn": (
3814 build_select,
3815 TosaTensorGen.tgBroadcastFuzz,
3816 TosaTensorValuesGen.tvgSelect,
3817 None,
3818 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003820 "error_if_validators": (
3821 TosaErrorValidator.evRankMismatch,
3822 TosaErrorValidator.evWrongInputType,
3823 TosaErrorValidator.evWrongOutputType,
3824 TosaErrorValidator.evWrongInputList,
3825 TosaErrorValidator.evWrongOutputList,
3826 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003827 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 # Comparison operators
3831 "equal": {
3832 "op": Op.EQUAL,
3833 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003834 "build_fcn": (
3835 build_comparison,
3836 TosaTensorGen.tgBroadcastFuzz,
3837 TosaTensorValuesGen.tvgEqual,
3838 None,
3839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003841 "error_if_validators": (
3842 TosaErrorValidator.evRankMismatch,
3843 TosaErrorValidator.evWrongInputType,
3844 TosaErrorValidator.evWrongOutputType,
3845 TosaErrorValidator.evWrongInputList,
3846 TosaErrorValidator.evWrongOutputList,
3847 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003848 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 "greater_equal": {
3852 "op": Op.GREATER_EQUAL,
3853 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003854 "build_fcn": (
3855 build_comparison,
3856 TosaTensorGen.tgBroadcastFuzz,
3857 TosaTensorValuesGen.tvgDefault,
3858 None,
3859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003861 "error_if_validators": (
3862 TosaErrorValidator.evRankMismatch,
3863 TosaErrorValidator.evWrongInputType,
3864 TosaErrorValidator.evWrongOutputType,
3865 TosaErrorValidator.evWrongInputList,
3866 TosaErrorValidator.evWrongOutputList,
3867 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003868 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003871 "greater": {
3872 "op": Op.GREATER,
3873 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003874 "build_fcn": (
3875 build_comparison,
3876 TosaTensorGen.tgBroadcastFuzz,
3877 TosaTensorValuesGen.tvgDefault,
3878 None,
3879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 "error_if_validators": (
3882 TosaErrorValidator.evRankMismatch,
3883 TosaErrorValidator.evWrongInputType,
3884 TosaErrorValidator.evWrongOutputType,
3885 TosaErrorValidator.evWrongInputList,
3886 TosaErrorValidator.evWrongOutputList,
3887 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003888 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 # Reduction operators
3892 "reduce_all": {
3893 "op": Op.REDUCE_ALL,
3894 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003895 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003896 "build_fcn": (
3897 build_reduce,
3898 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003899 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003900 TosaArgGen.agAxis,
3901 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003902 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003903 "error_if_validators": (
3904 TosaErrorValidator.evAxisLargerRank,
3905 TosaErrorValidator.evAxisSmallerZero,
3906 TosaErrorValidator.evShapeOfAxisNotOne,
3907 TosaErrorValidator.evWrongInputType,
3908 TosaErrorValidator.evWrongOutputType,
3909 TosaErrorValidator.evWrongRank,
3910 TosaErrorValidator.evWrongInputList,
3911 TosaErrorValidator.evWrongOutputList,
3912 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 "reduce_any": {
3915 "op": Op.REDUCE_ANY,
3916 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003917 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003918 "build_fcn": (
3919 build_reduce,
3920 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003921 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 TosaArgGen.agAxis,
3923 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003924 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 "error_if_validators": (
3926 TosaErrorValidator.evAxisLargerRank,
3927 TosaErrorValidator.evAxisSmallerZero,
3928 TosaErrorValidator.evShapeOfAxisNotOne,
3929 TosaErrorValidator.evWrongInputType,
3930 TosaErrorValidator.evWrongOutputType,
3931 TosaErrorValidator.evWrongRank,
3932 TosaErrorValidator.evWrongInputList,
3933 TosaErrorValidator.evWrongOutputList,
3934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 "reduce_max": {
3937 "op": Op.REDUCE_MAX,
3938 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003939 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003940 "build_fcn": (
3941 build_reduce,
3942 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003943 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003944 TosaArgGen.agAxis,
3945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003947 "error_if_validators": (
3948 TosaErrorValidator.evAxisLargerRank,
3949 TosaErrorValidator.evAxisSmallerZero,
3950 TosaErrorValidator.evShapeOfAxisNotOne,
3951 TosaErrorValidator.evWrongInputType,
3952 TosaErrorValidator.evWrongOutputType,
3953 TosaErrorValidator.evWrongRank,
3954 TosaErrorValidator.evWrongInputList,
3955 TosaErrorValidator.evWrongOutputList,
3956 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003957 "data_gen": {
3958 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003960 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003962 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003964 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003965 "build_fcn": (
3966 build_reduce,
3967 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003968 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003969 TosaArgGen.agAxis,
3970 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003971 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 "error_if_validators": (
3973 TosaErrorValidator.evAxisLargerRank,
3974 TosaErrorValidator.evAxisSmallerZero,
3975 TosaErrorValidator.evShapeOfAxisNotOne,
3976 TosaErrorValidator.evWrongInputType,
3977 TosaErrorValidator.evWrongOutputType,
3978 TosaErrorValidator.evWrongRank,
3979 TosaErrorValidator.evWrongInputList,
3980 TosaErrorValidator.evWrongOutputList,
3981 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003982 "data_gen": {
3983 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3984 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "reduce_product": {
3987 "op": Op.REDUCE_PRODUCT,
3988 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003989 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003990 "build_fcn": (
3991 build_reduce,
3992 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003993 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 TosaArgGen.agAxis,
3995 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003996 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003997 "error_if_validators": (
3998 TosaErrorValidator.evAxisLargerRank,
3999 TosaErrorValidator.evAxisSmallerZero,
4000 TosaErrorValidator.evShapeOfAxisNotOne,
4001 TosaErrorValidator.evWrongInputType,
4002 TosaErrorValidator.evWrongOutputType,
4003 TosaErrorValidator.evWrongRank,
4004 TosaErrorValidator.evWrongInputList,
4005 TosaErrorValidator.evWrongOutputList,
4006 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004007 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004008 "reduce_sum": {
4009 "op": Op.REDUCE_SUM,
4010 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004011 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004012 "build_fcn": (
4013 build_reduce,
4014 TosaTensorGen.tgBasic,
4015 TosaTensorValuesGen.tvgReduceSum,
4016 TosaArgGen.agAxis,
4017 ),
James Ward24dbc422022-10-19 12:20:31 +01004018 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004019 "error_if_validators": (
4020 TosaErrorValidator.evAxisLargerRank,
4021 TosaErrorValidator.evAxisSmallerZero,
4022 TosaErrorValidator.evShapeOfAxisNotOne,
4023 TosaErrorValidator.evWrongInputType,
4024 TosaErrorValidator.evWrongOutputType,
4025 TosaErrorValidator.evWrongRank,
4026 TosaErrorValidator.evWrongInputList,
4027 TosaErrorValidator.evWrongOutputList,
4028 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004029 "data_gen": {
4030 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004033 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004034 "concat": {
4035 "op": Op.CONCAT,
4036 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004037 "build_fcn": (
4038 build_concat,
4039 TosaTensorGen.tgConcat,
4040 TosaTensorValuesGen.tvgConcat,
4041 TosaArgGen.agAxis,
4042 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004043 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004044 "error_if_validators": (
4045 TosaErrorValidator.evAxisLargerRank,
4046 TosaErrorValidator.evAxisSmallerZero,
4047 TosaErrorValidator.evConcatInputRankMismatch,
4048 TosaErrorValidator.evConcatShapeSumMismatch,
4049 TosaErrorValidator.evConcatInputDimMismatch,
4050 TosaErrorValidator.evWrongInputType,
4051 TosaErrorValidator.evWrongOutputType,
4052 TosaErrorValidator.evWrongOutputList,
4053 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004054 },
4055 "pad": {
4056 "op": Op.PAD,
4057 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004058 "build_fcn": (
4059 build_pad,
4060 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004061 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 TosaArgGen.agPad,
4063 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004064 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004065 "error_if_validators": (
4066 TosaErrorValidator.evWrongInputType,
4067 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004068 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 TosaErrorValidator.evWrongOutputType,
4070 TosaErrorValidator.evWrongInputList,
4071 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004072 TosaErrorValidator.evRankMismatch,
4073 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004074 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004075 "data_gen": {
4076 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4077 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004078 },
Won Jeona21b2e82023-08-10 10:33:01 +00004079 "dim": {
4080 "op": Op.DIM,
4081 "operands": (1, 0),
4082 "build_fcn": (
4083 build_dim,
4084 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004085 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004086 TosaArgGen.agAxis,
4087 ),
4088 "types": TYPE_FIB,
4089 "error_if_validators": (
4090 TosaErrorValidator.evAxisLargerRank,
4091 TosaErrorValidator.evAxisSmallerZero,
4092 TosaErrorValidator.evWrongInputType,
4093 TosaErrorValidator.evWrongInputList,
4094 TosaErrorValidator.evWrongOutputList,
4095 TosaErrorValidator.evWrongRank,
4096 ),
4097 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004098 "reshape": {
4099 "op": Op.RESHAPE,
4100 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004101 "build_fcn": (
4102 build_reshape,
4103 TosaTensorGen.tgBasic,
4104 TosaTensorValuesGen.tvgDefault,
4105 TosaArgGen.agReshape,
4106 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004107 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004108 "error_if_validators": (
4109 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4110 TosaErrorValidator.evWrongInputType,
4111 TosaErrorValidator.evWrongOutputType,
4112 TosaErrorValidator.evWrongInputList,
4113 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004114 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4115 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004116 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004117 },
4118 "reverse": {
4119 "op": Op.REVERSE,
4120 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004121 "build_fcn": (
4122 build_reverse,
4123 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004124 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004125 TosaArgGen.agAxis,
4126 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004127 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 "error_if_validators": (
4129 TosaErrorValidator.evAxisSmallerZero,
4130 TosaErrorValidator.evAxisLargerRank,
4131 TosaErrorValidator.evWrongInputType,
4132 TosaErrorValidator.evWrongOutputType,
4133 TosaErrorValidator.evWrongInputList,
4134 TosaErrorValidator.evWrongOutputList,
4135 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004136 },
4137 "slice": {
4138 "op": Op.SLICE,
4139 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004140 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004141 "build_fcn": (
4142 build_slice,
4143 TosaTensorGen.tgBasic,
4144 TosaTensorValuesGen.tvgDefault,
4145 TosaArgGen.agSlice,
4146 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004147 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004148 "error_if_validators": (
4149 TosaErrorValidator.evStartSmallerZero,
4150 TosaErrorValidator.evSizeSmallerEqualZero,
4151 TosaErrorValidator.evStartSizeOutsideBounds,
4152 TosaErrorValidator.evSizeOutputShapeMismatch,
4153 TosaErrorValidator.evInputSizeStartLengthMismatch,
4154 TosaErrorValidator.evWrongRank,
4155 TosaErrorValidator.evWrongInputType,
4156 TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongInputList,
4158 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004159 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004160 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004161 },
4162 "tile": {
4163 "op": Op.TILE,
4164 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004165 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004166 "build_fcn": (
4167 build_tile,
4168 TosaTensorGen.tgBasic,
4169 TosaTensorValuesGen.tvgDefault,
4170 TosaArgGen.agTile,
4171 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004172 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004173 "error_if_validators": (
4174 TosaErrorValidator.evWrongInputType,
4175 TosaErrorValidator.evWrongOutputType,
4176 TosaErrorValidator.evWrongInputList,
4177 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004178 TosaErrorValidator.evRankMismatch,
4179 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004180 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004181 },
4182 "transpose": {
4183 "op": Op.TRANSPOSE,
4184 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004185 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004186 "build_fcn": (
4187 build_transpose,
4188 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004189 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004190 TosaArgGen.agTranspose,
4191 ),
4192 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004193 "error_if_validators": (
4194 TosaErrorValidator.evIndexOutsideBounds,
4195 TosaErrorValidator.evIndexUsedTwice,
4196 TosaErrorValidator.evWrongInputType,
4197 TosaErrorValidator.evWrongOutputType,
4198 TosaErrorValidator.evWrongInputList,
4199 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004200 TosaErrorValidator.evWrongRank,
4201 TosaErrorValidator.evRankMismatch,
4202 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004203 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004204 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004205 # Data nodes
4206 "const": {
4207 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004208 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004209 "build_fcn": (
4210 build_const,
4211 TosaTensorGen.tgBasic,
4212 TosaTensorValuesGen.tvgDefault,
4213 None,
4214 ),
Luke Hutton65872422023-02-20 10:33:04 +00004215 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004217 "identity": {
4218 "op": Op.IDENTITY,
4219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 "build_fcn": (
4221 build_unary,
4222 TosaTensorGen.tgBasic,
4223 TosaTensorValuesGen.tvgDefault,
4224 None,
4225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 "types": TYPE_FIB,
4227 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004228 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004229 "gather": {
4230 "op": Op.GATHER,
4231 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4232 "operands": (1, 0),
4233 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004234 "build_fcn": (
4235 build_gather,
4236 TosaTensorGen.tgBasic,
4237 TosaTensorValuesGen.tvgDefault,
4238 None,
4239 ),
James Ward24dbc422022-10-19 12:20:31 +01004240 "types": (
4241 DType.INT8,
4242 DType.INT16,
4243 DType.INT32,
4244 DType.FP16,
4245 DType.BF16,
4246 DType.FP32,
4247 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evWrongInputType,
4250 TosaErrorValidator.evWrongOutputType,
4251 TosaErrorValidator.evWrongInputList,
4252 TosaErrorValidator.evWrongOutputList,
4253 TosaErrorValidator.evWrongRank,
4254 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004255 },
4256 "scatter": {
4257 "op": Op.SCATTER,
4258 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004259 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 "operands": (2, 0),
4261 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004262 "build_fcn": (
4263 build_scatter,
4264 TosaTensorGen.tgScatter,
4265 TosaTensorValuesGen.tvgDefault,
4266 None,
4267 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004268 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004269 "error_if_validators": (
4270 TosaErrorValidator.evWrongInputType,
4271 TosaErrorValidator.evWrongOutputType,
4272 TosaErrorValidator.evWrongInputList,
4273 TosaErrorValidator.evWrongOutputList,
4274 TosaErrorValidator.evWrongRank,
4275 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004276 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004277 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004278 "resize": {
4279 "op": Op.RESIZE,
4280 "operands": (1, 0),
4281 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_resize,
4284 TosaTensorGen.tgNHWC,
4285 TosaTensorValuesGen.tvgDefault,
4286 TosaArgGen.agResize,
4287 ),
James Ward24dbc422022-10-19 12:20:31 +01004288 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004289 "invalid_test_validators": (
4290 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004291 ),
4292 "error_if_validators": (
4293 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004294 TosaErrorValidator.evScaleSmallerEqualZero,
4295 TosaErrorValidator.evScaleNLargerMax,
4296 TosaErrorValidator.evScaleDLargerMax,
4297 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004298 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004299 TosaErrorValidator.evBorderSmallerMin,
4300 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004301 TosaErrorValidator.evWrongInputType,
4302 TosaErrorValidator.evWrongOutputType,
4303 TosaErrorValidator.evWrongRank,
4304 TosaErrorValidator.evWrongInputList,
4305 TosaErrorValidator.evWrongOutputList,
4306 TosaErrorValidator.evBatchMismatch,
4307 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004308 TosaErrorValidator.evResizeOutputShapeMismatch,
4309 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004310 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004311 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004312 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004313 "cast": {
4314 "op": Op.CAST,
4315 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 "build_fcn": (
4317 build_cast,
4318 TosaTensorGen.tgBasic,
4319 TosaTensorValuesGen.tvgDefault,
4320 TosaArgGen.agCast,
4321 ),
James Ward8b390432022-08-12 20:48:56 +01004322 "types": (
4323 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004324 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004325 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004326 DType.INT8,
4327 DType.INT16,
4328 DType.INT32,
4329 DType.BOOL,
4330 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004331 "error_if_validators": (
4332 TosaErrorValidator.evWrongInputType,
4333 TosaErrorValidator.evWrongOutputType,
4334 TosaErrorValidator.evWrongInputList,
4335 TosaErrorValidator.evWrongOutputList,
4336 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004337 },
4338 "rescale": {
4339 "op": Op.RESCALE,
4340 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004341 "build_fcn": (
4342 build_rescale,
4343 TosaTensorGen.tgBasic,
4344 TosaTensorValuesGen.tvgDefault,
4345 TosaArgGen.agRescale,
4346 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004347 "types": [
4348 DType.UINT8,
4349 DType.INT8,
4350 DType.INT16,
4351 DType.INT32,
4352 DType.INT48,
4353 DType.UINT16,
4354 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004355 "error_if_validators": (
4356 TosaErrorValidator.evInputZeroPointNotZero,
4357 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004358 TosaErrorValidator.evU16InputZeroPointNotValid,
4359 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004360 TosaErrorValidator.evScaleTrue,
4361 TosaErrorValidator.evScaleNotTrue,
4362 TosaErrorValidator.evWrongInputType,
4363 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004364 TosaErrorValidator.evWrongInputList,
4365 TosaErrorValidator.evWrongOutputList,
4366 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004367 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004368 # Custom
4369 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004371 # Two variants of cond_if: one that generates one of two constant tensors (no
4372 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4373 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004374 "cond_if_const": {
4375 "op": Op.COND_IF,
4376 "operands": (0, 2),
4377 "build_fcn": (
4378 build_cond_if_const,
4379 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004380 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004381 TosaArgGen.agCondIf,
4382 ),
4383 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004384 "error_if_validators": (
4385 TosaErrorValidator.evOutputListThenGraphMismatch,
4386 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004387 TosaErrorValidator.evCondIfCondNotMatchingBool,
4388 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004389 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004390 },
4391 "cond_if_binary": {
4392 "op": Op.COND_IF,
4393 "operands": (2, 0),
4394 "build_fcn": (
4395 build_cond_if_binary,
4396 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004397 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004398 TosaArgGen.agCondIf,
4399 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004400 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004401 "error_if_validators": (
4402 TosaErrorValidator.evInputListThenGraphMismatch,
4403 TosaErrorValidator.evInputListElseGraphMismatch,
4404 TosaErrorValidator.evOutputListThenGraphMismatch,
4405 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004406 TosaErrorValidator.evCondIfCondNotMatchingBool,
4407 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004408 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004409 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004410 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004411 "while_loop": {
4412 "op": Op.WHILE_LOOP,
4413 "operands": (0, 1),
4414 "build_fcn": (
4415 build_while_loop,
4416 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004417 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004418 TosaArgGen.agWhileLoop,
4419 ),
4420 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004421 "error_if_validators": (
4422 TosaErrorValidator.evInputListOutputListMismatch,
4423 TosaErrorValidator.evInputListCondGraphMismatch,
4424 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4425 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4426 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004427 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004428 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004429 },
Luke Hutton57287132023-02-06 14:54:18 +00004430 "fft2d": {
4431 "op": Op.FFT2D,
4432 "operands": (2, 0),
4433 "rank": (3, 3),
4434 "build_fcn": (
4435 build_fft2d,
4436 TosaTensorGen.tgFFT2d,
4437 TosaTensorValuesGen.tvgDefault,
4438 TosaArgGen.agFFT2d,
4439 ),
4440 "types": [DType.FP32],
4441 "error_if_validators": (
4442 TosaErrorValidator.evWrongInputType,
4443 TosaErrorValidator.evWrongOutputType,
4444 TosaErrorValidator.evWrongInputList,
4445 TosaErrorValidator.evWrongOutputList,
4446 TosaErrorValidator.evWrongRank,
4447 TosaErrorValidator.evBatchMismatch,
4448 TosaErrorValidator.evKernelNotPowerOfTwo,
4449 TosaErrorValidator.evFFTInputShapeMismatch,
4450 TosaErrorValidator.evFFTOutputShapeMismatch,
4451 ),
4452 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004453 "rfft2d": {
4454 "op": Op.RFFT2D,
4455 "operands": (1, 0),
4456 "rank": (3, 3),
4457 "build_fcn": (
4458 build_rfft2d,
4459 TosaTensorGen.tgRFFT2d,
4460 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004461 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004462 ),
4463 "types": [DType.FP32],
4464 "error_if_validators": (
4465 TosaErrorValidator.evWrongInputType,
4466 TosaErrorValidator.evWrongOutputType,
4467 TosaErrorValidator.evWrongInputList,
4468 TosaErrorValidator.evWrongOutputList,
4469 TosaErrorValidator.evWrongRank,
4470 TosaErrorValidator.evBatchMismatch,
4471 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004472 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004473 ),
4474 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004475 }
4476
Kevin Cheng550ccc52021-03-03 11:21:43 -08004477
Eric Kunzee5e26762020-10-13 16:11:07 -07004478class OutputShaper:
4479 # Methods in this class compute the expected output shape and datatype
4480 # for common classes of operations
def __init__(self):
    """Nothing to initialise: the shape helpers here are used as static methods."""
4483
4484 # The methods below compute an op's output shape and dtype, create the
4485 # output tensor with ser.addOutput() and return its result
4486 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of a broadcasting binary element-wise op.

    Without an injected error, each output dimension is ``a``'s size unless
    that size is 1, in which case ``b``'s size at the same position is used.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator (``integers`` and ``choice`` are used);
            presumably a numpy ``Generator``.
        a: First input tensor (``shape`` and ``dtype`` are read).
        b: Second input tensor. Must have ``a``'s dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, ``a``'s rank.
        error_name: Optional ``ErrorIf`` case to inject.
            ``DimensionMismatch`` grows one random output dimension by 1.
            ``WrongOutputType`` picks a random dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        # Broadcasting is only resolved when no error is being injected
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn even when DimensionMismatch is not requested,
    # keeping the random sequence (and so the generated tests) the same
    # for a given seed. A rank-0 input has no dimension to fuzz, and
    # rng.integers(0, 0) raises ValueError, so skip the draw in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004519
4520 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor of a binary op whose inputs must match exactly.

    Rank, every dimension and dtype of ``a`` and ``b`` must agree (checked
    with asserts). The output gets a fresh copy of ``a``'s shape as a list,
    plus ``a``'s dtype.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
    out_shape = list(a.shape)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004531
4532 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor of a unary element-wise op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype``, except
    under ``ErrorIf.WrongOutputType``, where a random different dtype is
    drawn with ``rng.choice``.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004550
4551 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor of SELECT(cond, a, b).

    The output shape follows ``cond``. Where ``cond`` has size 1 and no
    error is injected, the largest of the three sizes at that position is
    used instead.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator (``integers`` and ``choice`` are used);
            presumably a numpy ``Generator``.
        cond: Condition tensor (``shape`` is read).
        a, b: Value tensors. Must share a dtype and, unless ``error_name``
            is ``ErrorIf.RankMismatch``, ``cond``'s rank.
        error_name: Optional ``ErrorIf`` case to inject.
            ``DimensionMismatch`` grows one random output dimension by 1.
            ``WrongOutputType`` picks a random dtype other than ``a.dtype``.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # The fuzz index is drawn even when DimensionMismatch is not requested,
    # keeping the random sequence (and so the generated tests) the same
    # for a given seed. A rank-0 input has no dimension to fuzz, and
    # rng.integers(0, 0) raises ValueError, so skip the draw in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004584
4585 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor of a binary comparison op.

    The output shape is ``a``'s shape, except that a size-1 dimension of
    ``a`` takes ``b``'s size at the same position whenever ``b`` has that
    dimension. This broadcast is applied even when an error is injected.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: Random generator (``integers`` and ``choice`` are used);
            presumably a numpy ``Generator``.
        a, b: Input tensors. Must share a dtype and, unless ``error_name``
            is ``ErrorIf.RankMismatch``, a rank.
        error_name: Optional ``ErrorIf`` case to inject.
            ``DimensionMismatch`` grows one random output dimension by 1.
            ``WrongOutputType`` picks a random dtype from a list that does
            not contain BOOL.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn even when DimensionMismatch is not requested,
    # keeping the random sequence (and so the generated tests) the same
    # for a given seed. A rank-0 input has no dimension to fuzz, and
    # rng.integers(0, 0) raises ValueError, so skip the draw in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is deliberately absent, so every choice is a wrong type
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004618
4619 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a reduction along ``axis``.

    Normally the output is ``a``'s shape with ``axis`` collapsed to 1.

    Error injection:
        AxisSmallerZero / AxisLargerRank: the input shape is kept unchanged.
        ShapeOfAxisNotOne: the axis size is kept, or replaced by a random
            value in [2, 10) when it already equals 1.
        WrongOutputType: a random dtype other than ``a.dtype`` is used.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        # Make sure the reduced axis is visibly not 1
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004647
4648 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of ARGMAX along ``axis``.

    Normally the output is ``a``'s shape with ``axis`` removed, with dtype
    INT32.

    Error injection:
        AxisSmallerZero / AxisLargerRank: the axis is not removed.
        ArgmaxOutputRankMismatch: randomly drop the leading dimension (only
            when more than one dimension remains), otherwise append a 1.
        ArgmaxOutputShapeMismatch: add a random 1..9 to every dimension.
        WrongOutputType: a random non-INT32 dtype is used.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004681
4682 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV2D.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    ``padding`` holds (before, after) pairs for H then W. The output dtype
    is ``accum_dtype``.

    Error injection:
        ConvOutputShapeMismatch: grow H, W or both by a random multiple of
            the stride.
        WrongInputType: use INT32 as the output dtype.
        WrongOutputType: pick a random usable dtype other than the
            expected one.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 3 means "both". Grow in multiples of the stride so the
        # non-integer output error case is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a bad input type, fall back to some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # accumulator type is treated as valid for them)
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004733
4734 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV3D.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    ``padding`` holds (before, after) pairs for D, H, then W. The output
    dtype is ``accum_dtype``.

    Error injection:
        ConvOutputShapeMismatch: grow D, H, W or all three by a random
            multiple of the stride.
        WrongInputType: use INT32 as the output dtype.
        WrongOutputType: pick a random usable dtype other than the
            expected one.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    # Spatial dims are shape[1..3] on both IFM and filter; index i selects
    # D, H, W in turn.
    d, h, w = [
        (
            ifm.shape[i + 1]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i + 1] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # 4 means "all dims". Grow in multiples of the stride so the
        # non-integer output error case is not hit.
        spatial = [d, h, w]
        for i in range(3):
            if change in (i + 1, 4):
                spatial[i] = spatial[i] + (rng.choice(choices) * strides[i])
        d, h, w = spatial

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # With a bad input type, fall back to some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # accumulator type is treated as valid for them)
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4795
4796 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of DEPTHWISE_CONV2D.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``,
    with the kernel size read from filter dims 0 (H) and 1 (W). ``padding``
    holds (before, after) pairs for H then W. The output dtype is
    ``accum_dtype``.

    Error injection:
        ConvOutputShapeMismatch: grow H, W or both by a random multiple of
            the stride.
        WrongInputType: use INT32 as the output dtype.
        WrongOutputType: pick a random usable dtype other than the
            expected one.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 3 means "both". Grow in multiples of the stride so the
        # non-integer output error case is not hit.
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    # Output channels = input channels (C) * channel multiplier (M)
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # With a bad input type, fall back to some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # accumulator type is treated as valid for them)
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004846
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling operator (NHWC input).

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        ifm: Input tensor (``shape`` and ``dtype`` are read).
        kernel: ``[kh, kw]`` window size.
        stride: ``[sy, sx]`` strides.
        pad: Four pads; entries 0-1 apply to height, 2-3 to width.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        # Invalid stride/padding makes the test invalid anyway; use a 1x1
        # spatial output so the tensor is still well-formed.
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output check is not hit.
        if which in (1, 3):
            out_h = out_h + rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w = out_w + rng.choice(steps) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004884
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor of a FULLY_CONNECTED operator.

    Shapes: input [N, IC], filter [OC, IC] -> output [N, OC].

    The output dtype is always ``accum_dtype``. Its validity, including for
    ERROR_IF tests, is checked during argument generation. ``rng`` and
    ``error_name`` are accepted for a uniform signature and are not used.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004897
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor of a MATMUL operator.

    Shapes: a [N, H, C], b [N, C, W] -> output [N, H, W].

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used to pick a wrong dtype for ERROR_IF tests.
        a: Left operand (``shape`` and ``dtype`` are read).
        b: Right operand (only ``shape`` is read).
        accum_dtype: Output dtype on the normal path (validated in arg_gen).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype used to leave incorrect_types unbound and
            # raise UnboundLocalError. Instead, choose from every usable dtype
            # except the accumulator type.
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004945
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor of a CONCAT operator.

    The output shape is the first input's shape with ``axis`` extended by
    the other inputs' sizes on that axis. When the ranks differ or the axis
    is invalid, the first input's shape is used unchanged.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        axis: Concatenation axis.
        inputs: Sequence of input tensors; the first supplies shape/dtype.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    head = inputs[0]
    others = inputs[1:]
    out_shape = head.shape.copy()

    # Summing along `axis` is meaningless for mismatched ranks or an invalid
    # axis, so those cases keep the first input's shape.
    skip_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not skip_sum:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = head.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {head.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004981
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor of a PAD operator.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        a: Input tensor (``shape`` and ``dtype`` are read).
        padding: Per-axis ``(before, after)`` pad amounts.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    for axis, dim in enumerate(a.shape):
        out_shape[axis] = padding[axis][0] + padding[axis][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005012
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a DIM operator.

    The output has an empty shape (``[]``) and type ``DType.SHAPE``. For
    ``ErrorIf.WrongOutputType`` a random non-SHAPE type is chosen instead.
    ``a`` and ``axis`` are not read here.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    non_shape_types = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # The candidates pass through set() before sampling; removing that step
    # could change which type a given seed selects.
    return ser.addOutput([], rng.choice(list(set(non_shape_types))))
5033
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor of a RESHAPE operator.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        a: Input tensor (only ``dtype`` is read).
        shape: Requested output shape (copied, not modified).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Inflate every dim so the element count no longer matches the input.
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005058
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor of a SLICE operator.

    The output shape is ``size``, except in the ERROR_IF cases below.
    ``start`` is not read here.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        input: Input tensor (``shape`` and ``dtype`` are read).
        start: Slice start (unused).
        size: Slice size, used as the output shape.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dims are only grown (presumably to keep them positive);
            # larger dims may shrink or grow by up to 2.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005092
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor of a TILE operator.

    Each output dim is ``a.shape[i] * multiples[i]``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        a: Input tensor (``shape`` and ``dtype`` are read).
        multiples: Per-axis repeat counts; must match the input rank.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for axis, (dim, reps) in enumerate(zip(a.shape, multiples)):
        out_shape[axis] = dim * reps

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005121
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor of a TRANSPOSE operator.

    Output dim ``i`` is ``a.shape[perms[i]]``. For the invalid-permutation
    ERROR_IF cases, the input shape is used unchanged.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        a: Input tensor (``shape`` and ``dtype`` are read).
        perms: Permutation; must have the input's rank.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    # An out-of-range or repeated index cannot be applied.
    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for axis, src in enumerate(perms):
            out_shape[axis] = a.shape[src]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for axis, dim in enumerate(out_shape):
            out_shape[axis] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005154
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor of a GATHER operator.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    Rank and batch checks are skipped for ``ErrorIf.WrongRank``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        values: Rank-3 values tensor (``shape`` and ``dtype`` are read).
        indices: Rank-2 indices tensor (only ``shape`` is read).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    num_indices = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, num_indices, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {values.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005180
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor of a SCATTER operator.

    The output has the shape of ``values_in``. Rank and dimension checks are
    skipped for ``ErrorIf.WrongRank``.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        values_in: Rank-3 tensor providing the output shape and dtype.
        indices: Rank-2 indices tensor.
        input: Rank-3 tensor of values to scatter.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # NOTE(review): this is the same object as values_in.shape, not a copy.
    out_shape = values_in.shape

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005209
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor of a TABLE operator.

    The output has the input's shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        input: Input tensor (``shape`` and ``dtype`` are read).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice([t for t in candidates if t != out_dtype])

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005229
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor of a RESIZE operator (NHWC input).

    OH/OW are computed from ``scale`` = ``[y_n, y_d, x_n, x_d]`` together
    with ``offset[0:2]`` and ``border[0:2]``. ``mode`` and ``input_dtype``
    are not read here.

    Args:
        serializer: Serializer; ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        input: Input tensor (only ``shape`` is read).
        output_dtype: Dtype given to the created output tensor.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``serializer.addOutput``.
    """
    y_num = scale[0]
    y_den = scale[1]
    x_num = scale[2]
    x_den = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the divisions below stay valid.
        y_num, y_den, x_num, x_den = (max(v, 1) for v in (y_num, y_den, x_num, x_den))

    def _scaled(in_size, numer, denom, off, bord):
        return ((in_size - 1) * numer - off + bord) // denom + 1

    oh = _scaled(input.shape[1], y_num, y_den, offset[0], border[0])
    ow = _scaled(input.shape[2], x_num, x_den, offset[1], border[1])

    if error_name is not None:
        # ERROR_IF tweaks to scale/offset/border can yield degenerate sizes;
        # keep the output tensor valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    def _nudge(size, step):
        # Move by a whole denominator so the non-integer error case is not hit.
        if size + step >= gtu.MAX_RESIZE_DIMENSION:
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size
        return size + step

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        if which in (1, 3):
            oh = _nudge(oh, y_den)
        if which in (2, 3):
            ow = _nudge(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim is input.shape[0], not shape[3],
        # presumably because the input may not be rank 4 in this case.
        output_dims = [batch, oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005308
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for a uniform signature and are
    not used.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005312
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor of a TRANSPOSE_CONV2D operator.

    Args:
        ser: Serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for ERROR_IF perturbations.
        ifm: Input tensor (only ``dtype`` is read).
        output_shape: Output shape; dims 1 and 2 may be bumped in place.
        accum_dtype: Output dtype on the normal path.
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # NOTE(review): output_shape is modified in place, so the caller's
        # object sees the change too. Confirm whether callers rely on this.
        if which in (1, 3):
            output_shape[1] = output_shape[1] + rng.choice(steps)
        if which in (2, 3):
            output_shape[2] = output_shape[2] + rng.choice(steps)

    # With a deliberately bad input type, fall back to a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005338
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors of an FFT2D operator.

    Both outputs share one shape list (a copy of ``ifm1.shape``) and one
    dtype (the input dtype, unless perturbed for ERROR_IF).

    Args:
        serializer: Serializer; ``addOutput(shape, dtype)`` creates a tensor.
        rng: Random generator used for ERROR_IF perturbations.
        ifm1: First input (presumably the real part); ``shape``/``dtype`` read.
        ifm2: Second input (presumably the imaginary part).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, dtype) for _ in range(2)]
5369
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors of an RFFT2D operator.

    The output shape is the input shape with the last dim replaced by
    ``W // 2 + 1``. Both outputs share that shape list and one dtype.

    Args:
        serializer: Serializer; ``addOutput(shape, dtype)`` creates a tensor.
        rng: Random generator used for ERROR_IF perturbations.
        value: Rank-3 input tensor (``shape`` and ``dtype`` are read).
        error_name: Optional ``ErrorIf`` case to provoke.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, dtype) for _ in range(2)]