blob: 556a0d89835b8147775b9a6977029033f347bb1c [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data source file;
# the copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
40 TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
41 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
69 vals.append(v)
70 return tuple(sorted(vals))
71
72 self.random_float_range = {}
73 for dtype in (DType.FP32, DType.FP16, DType.BF16):
74 self.random_float_range[dtype] = convertFPRange(
75 args.tensor_fp_value_range,
76 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
77 )
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 def createSerializer(self, opName, testPath):
80 self.testPath = os.path.join(opName, testPath)
81
82 fullPath = os.path.join(self.basePath, self.testPath)
83 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010084 # Embed const data in the flatbuffer
85 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010086 if self.args.lazy_data_gen:
87 # Lazy data generation - so make constants files
88 constMode = ts.ConstMode.INPUTS
89 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 constMode = ts.ConstMode.EMBED_DUMP
91 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
93 def getSerializer(self):
94 return self.ser
95
Jeremy Johnson1271c442023-09-05 11:39:26 +010096 def serialize(self, testName, metaData=None):
97 path = Path(self.basePath) / self.testPath
98
99 # Write out TOSA flatbuffer binary
100 path_fb = path / f"{testName}.tosa"
101 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 fd.write(self.ser.serialize())
103
Jeremy Johnson1271c442023-09-05 11:39:26 +0100104 # Get JSON descriptor from serializer
105 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
106
107 if metaData:
108 # Add extra meta data to desc.json
109 desc["meta"] = metaData
110
111 # Validate desc.json before we output it
112 self.descSchemaValidator.validate_config(desc)
113
114 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100115 if "data_gen" in metaData:
116 if self.args.lazy_data_gen:
117 # Output datagen meta data as CPP data
118 path_md = path / f"{testName}_meta_data_gen.cpp"
119 with path_md.open("w") as fd:
120 fd.write(TOSA_AUTOGENERATED_HEADER)
121 fd.write("// Test meta data for data generation setup\n\n")
122 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
123 json.dump(metaData["data_gen"], fd)
124 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100125 if "compliance" in metaData:
126 # Output datagen meta data as CPP data
127 path_md = path / f"{testName}_meta_compliance.cpp"
128 with path_md.open("w") as fd:
129 fd.write(TOSA_AUTOGENERATED_HEADER)
130 fd.write("// Test meta data for compliance validation\n\n")
131 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
132 json.dump(metaData["compliance"], fd)
133 fd.write(')";\n\n')
134
135 # Write desc.json
136 path_desc = path / "desc.json"
137 with path_desc.open("w") as fd:
138 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
Matthew Haddon74567092021-07-16 15:38:20 +0100140 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000141 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100142 seed = self.random_seed + 1
143 self.rng = np.random.default_rng(seed)
144
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 def getDTypeRange(self, dtype, high_inclusive=False):
146 # Returns dtype value range boundaries (low, high)
147 # The high boundary is excluded in the range
148 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100149 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100150 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 elif dtype == DType.BOOL:
152 rng = (0, 2)
153 elif dtype == DType.UINT8:
154 rng = (0, 256)
155 elif dtype == DType.UINT16:
156 rng = (0, 65536)
157 elif dtype == DType.INT4:
158 # TOSA specific INT4 weight range from -7 to 7
159 rng = (-7, 8)
160 elif dtype == DType.INT8:
161 rng = (-128, 128)
162 elif dtype == DType.INT16:
163 rng = (-32768, 32768)
164 elif dtype in (DType.INT32, DType.SHAPE):
165 # restricting too large value for SHAPE
166 rng = (-(1 << 31), (1 << 31))
167 elif dtype == DType.INT48:
168 rng = (-(1 << 47), (1 << 47))
169 else:
170 raise Exception("Unknown dtype: {}".format(dtype))
171
172 if not high_inclusive:
173 # Exclusive high: low <= range < high
174 return rng
175 else:
176 # Inclusive range: low <= range <= high
177 return (rng[0], rng[1] - 1)
178
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 def getRandTensor(self, shape, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100180 low, high = self.getDTypeRange(dtype)
181
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700183 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100185 return np.int64(self.rng.integers(low=low, high=high, size=shape))
186 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
187 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
188
189 if dtype == DType.FP16:
190 return np.float16(f_tensor)
191 else:
192 f32_tensor = np.float32(f_tensor)
193 if dtype == DType.BF16:
194 # Floor the last 16 bits of each f32 value
195 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
196 else:
197 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700198 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100199 # All other integer types
200 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700201
Kevin Cheng989cb052021-04-28 16:29:44 -0700202 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700203 placeholders = []
204
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 assert len(shape_list) == len(dtype_list)
206
Jeremy Johnson1271c442023-09-05 11:39:26 +0100207 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700208 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100209 if not self.args.lazy_data_gen:
210 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700212
213 return placeholders
214
Kevin Cheng989cb052021-04-28 16:29:44 -0700215 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700216 consts = []
217
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 assert len(shape_list) == len(dtype_list)
219
Jeremy Johnson1271c442023-09-05 11:39:26 +0100220 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100222 if not self.args.lazy_data_gen:
223 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700225
226 return consts
227
228 def makeShape(self, rank):
229 if self.targetted_shape:
230 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800231 return np.int32(
232 self.rng.integers(
233 low=self.args.tensor_shape_range[0],
234 high=self.args.tensor_shape_range[1],
235 size=rank,
236 )
237 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
239 def setTargetShape(self, shape):
240 self.targetted_shape = shape
241
242 def randInt(self, low=0, high=256):
243 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
244
245 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100246 low, high = self.getDTypeRange(dtype)
247
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100248 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100249 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100250 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100251 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100252 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100253 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
254 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700255 elif dtype == DType.BOOL:
256 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700257 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 # Special size
259 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700260
261 return np.int32(self.rng.integers(low, high, size=1))[0]
262
263 def shapeStr(self, shape):
264
265 sStr = []
266 # Convert to strings
267 for i in shape:
268 sStr.append(str(i))
269
Kevin Cheng550ccc52021-03-03 11:21:43 -0800270 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100272 def typeStr(self, dtype):
273 if isinstance(dtype, list) or isinstance(dtype, tuple):
274 assert len(dtype) >= 2
275 strs = [self.typeStr(t) for t in dtype]
276 # Limit types to the first 2 as the 3rd is the accumulator
277 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700278 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100279 if dtype in gtu.DTYPE_ATTRIBUTES:
280 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700281 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100282 raise Exception(
283 "Unknown dtype, cannot convert to string: {}".format(dtype)
284 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700285
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100286 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100287 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100288 if dtype in gtu.DTYPE_ATTRIBUTES:
289 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700290 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100291 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
Luke Hutton57287132023-02-06 14:54:18 +0000293 def constrictBatchSize(self, shape):
294 # Limit the batch size unless an explicit target shape set
295 if self.args.max_batch_size and not self.args.target_shapes:
296 shape[0] = min(shape[0], self.args.max_batch_size)
297 return shape
298
James Ward30124a82023-02-02 14:56:33 +0000299 def makeDimension(self):
300 return self.randInt(
301 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
302 )
303
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100304 def tensorComplianceMetaData(
305 self, op, inputType, argsDict, outputTensor, errorName
306 ):
307 if (
308 errorName
309 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
310 or not gtu.dtypeIsSupportedByCompliance(inputType)
311 ):
312 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100313 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100314
Jeremy Johnson1271c442023-09-05 11:39:26 +0100315 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100316 compliance_tens = {
317 "mode": None,
318 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
319 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
320 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100321 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
322 mode = gtu.ComplianceMode.DOT_PRODUCT
323 compliance_tens["dot_product_info"] = {
324 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100325 "ks": int(argsDict["ksb"])
326 if "ksb" in argsDict
327 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100328 }
329 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
330 mode = gtu.ComplianceMode.FP_SPECIAL
331 elif "compliance" in op and "ulp" in op["compliance"]:
332 mode = gtu.ComplianceMode.ULP
333 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
334 elif op["op"] == Op.REDUCE_PRODUCT:
335 mode = gtu.ComplianceMode.REDUCE_PRODUCT
336 else:
337 mode = gtu.ComplianceMode.EXACT
338 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
339
340 return compliance_tens
341
342 # Build Op functions
343 # Create the output tensor (calling OutputShaper as needed)
344 # Do final tweaks to attributes (if necessary for errorIf)
345 # Add Op into graph
346 # Return resulting tensor information or BuildInfo
347
348 class BuildInfo:
349 """Enhanced build information containing result tensor and associated compliance dict."""
350
351 def __init__(self, resultTensor, complianceDict):
352 self.resultTensor = resultTensor
353 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700354
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100355 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
356 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
357
Matthew Haddon848efb42021-09-09 12:30:53 +0100358 # build_placeholder returns an int, ABS/other ops does not
359 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000360 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100361 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000362 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000363 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100364 return result_tens
365
366 # Ensure new output type has correct qinfo
367 if error_name == ErrorIf.WrongOutputType:
368 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000369 qinfo = [
370 TosaQuantGen.getZeroPoint(self, a.dtype),
371 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
372 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100373
374 # Invalidate Input/Output list for error if checks.
375 input_list = [a.name]
376 output_list = [result_tens.name]
377 pCount, cCount = op["operands"]
378 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000379 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
380 self, error_name, input_list, output_list
381 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100382
Les Bell729b0352021-11-24 10:28:21 +0000383 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100384 self.ser,
385 validator_fcns,
386 error_name,
387 op=op,
388 input_dtype=a.dtype,
389 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000390 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000391 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100392 input_list=input_list,
393 output_list=output_list,
394 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000395 ):
396 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100397
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000398 attr = None
399 if op["op"] == Op.NEGATE:
400 attr = ts.TosaSerializerAttribute()
401 attr.NegateAttribute(qinfo[0], qinfo[1])
402
403 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700404 return result_tens
405
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000406 def build_binary_broadcast(
407 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
408 ):
409 assert len(inputs) == 2
410 a, b = inputs
411 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000412 self.ser, self.rng, a, b, error_name
413 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100414
415 # Invalidate Input/Output list for error if checks.
416 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000417 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100418 pCount, cCount = op["operands"]
419 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000420 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
421 self, error_name, input_list, output_list
422 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100423
Les Bell729b0352021-11-24 10:28:21 +0000424 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100425 self.ser,
426 validator_fcns,
427 error_name,
428 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000429 input1=a,
430 input2=b,
431 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000432 output_dtype=result_tensor.dtype,
433 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100434 input_list=input_list,
435 output_list=output_list,
436 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000437 ):
438 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100439
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000440 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000441
442 if op["op"] == Op.POW:
443 # TODO - add compliance support
444 compliance = None
445 else:
446 compliance = self.tensorComplianceMetaData(
447 op, a.dtype, args_dict, result_tensor, error_name
448 )
449
450 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700451
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700453 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000454 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return result_tens
456
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 def build_arithmetic_right_shift(
458 self, op, a, b, round, validator_fcns=None, error_name=None
459 ):
460 result_tens = OutputShaper.binaryBroadcastOp(
461 self.ser, self.rng, a, b, error_name
462 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100463
464 # Invalidate Input/Output list for error if checks.
465 input_list = [a.name, b.name]
466 output_list = [result_tens.name]
467 pCount, cCount = op["operands"]
468 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000469 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
470 self, error_name, input_list, output_list
471 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100472
Les Bell729b0352021-11-24 10:28:21 +0000473 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100474 self.ser,
475 validator_fcns,
476 error_name,
477 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000478 input1=a,
479 input2=b,
480 input_dtype=a.dtype,
481 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000482 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100483 input_list=input_list,
484 output_list=output_list,
485 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000486 ):
487 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800488
489 attr = ts.TosaSerializerAttribute()
490 attr.ArithmeticRightShiftAttribute(round)
491
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000492 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800493 return result_tens
494
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100495 def build_mul(
496 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
497 ):
498 assert len(inputs) == 2
499 a, b = inputs
500 shift = args_dict["shift"]
501
502 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000503 self.ser, self.rng, a, b, error_name
504 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700505
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100506 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100507 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100508 result_tensor.setDtype(DType.INT32)
509
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100510 if error_name == ErrorIf.WrongOutputType:
511 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
512 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100513 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100514
515 # Invalidate Input/Output list for error if checks.
516 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100517 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100518 pCount, cCount = op["operands"]
519 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000520 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
521 self, error_name, input_list, output_list
522 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100523
Les Bell729b0352021-11-24 10:28:21 +0000524 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100525 self.ser,
526 validator_fcns,
527 error_name,
528 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000529 input1=a,
530 input2=b,
531 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100532 output_dtype=result_tensor.dtype,
533 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100534 input_list=input_list,
535 output_list=output_list,
536 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000537 ):
538 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700539
Kevin Chengaee1fac2020-11-11 13:54:06 -0800540 attr = ts.TosaSerializerAttribute()
541 attr.MulAttribute(shift)
542
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000543 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100544
545 compliance = self.tensorComplianceMetaData(
546 op, a.dtype, args_dict, result_tensor, error_name
547 )
548
549 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100551 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
552 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Kevin Chengfe392ce2021-10-18 21:51:55 +0000554 attr = ts.TosaSerializerAttribute()
555 attr.TableAttribute(table)
556
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100557 # Invalidate Input/Output list for error if checks.
558 input_list = [a.name]
559 output_list = [result_tens.name]
560 pCount, cCount = op["operands"]
561 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000562 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
563 self, error_name, input_list, output_list
564 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100565
Les Bell729b0352021-11-24 10:28:21 +0000566 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100567 self.ser,
568 validator_fcns,
569 error_name,
570 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000571 input_shape=a.shape,
572 input_dtype=a.dtype,
573 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000574 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100575 input_list=input_list,
576 output_list=output_list,
577 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000578 ):
579 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100580
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000581 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700582
583 return result_tens
584
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100585 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
586 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
587
588 # Invalidate Input/Output list for error if checks.
589 input_list = [cond.name, a.name, b.name]
590 output_list = [result_tens.name]
591 pCount, cCount = op["operands"]
592 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000593 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
594 self, error_name, input_list, output_list
595 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100596
Les Bell729b0352021-11-24 10:28:21 +0000597 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100598 self.ser,
599 validator_fcns,
600 error_name,
601 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000602 input1=cond,
603 input2=a,
604 input3=b,
605 input_shape=a.shape,
606 input_dtype=a.dtype,
607 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000608 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100609 input_list=input_list,
610 output_list=output_list,
611 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000612 ):
613 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100614
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000615 self.ser.addOperator(
616 op["op"],
617 input_list,
618 output_list,
619 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700620 return result_tens
621
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100622 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000623 result_tens = OutputShaper.binaryComparisonOp(
624 self.ser, self.rng, a, b, error_name
625 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100626
627 # Invalidate Input/Output list for error if checks.
628 input_list = [a.name, b.name]
629 output_list = [result_tens.name]
630 pCount, cCount = op["operands"]
631 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000632 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
633 self, error_name, input_list, output_list
634 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100635
Les Bell729b0352021-11-24 10:28:21 +0000636 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100637 self.ser,
638 validator_fcns,
639 error_name,
640 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000641 input1=a,
642 input2=b,
643 input_shape=a.shape,
644 input_dtype=a.dtype,
645 output_shape=result_tens.shape,
646 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000647 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100648 input_list=input_list,
649 output_list=output_list,
650 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000651 ):
652 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100653
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000654 self.ser.addOperator(
655 op["op"],
656 input_list,
657 output_list,
658 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700659 return result_tens
660
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Build an ARGMAX operator over tensor ``a`` along ``axis``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    n_placeholders, n_consts = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], ins, outs, axis_attr)
    return out_tensor
695
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator from a single input tensor.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: List holding exactly one input tensor.
        args_dict: Test arguments. "stride", "pad" and "kernel" are required.
            "acc_type" is optional because max_pool has no accumulator type.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to PoolAttribute, or
            None to use [0, 0].

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 1
    # Named ifm rather than "input" to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
764
def build_maxpool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MAX_POOL2D operator and attach compliance meta data.

    The operator itself is built by build_pool2d. Compliance data is
    computed only for a successfully built test.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        build_pool2d rejected the test during ERROR_IF validation.
    """
    result_tensor = self.build_pool2d(
        op,
        inputs,
        args_dict,
        validator_fcns,
        error_name,
        qinfo,
    )
    if result_tensor is None:
        # build_pool2d signals an invalid test with None. Propagate that
        # signal the same way build_conv2d/build_matmul do, rather than
        # computing compliance for, and wrapping, a missing tensor.
        return None

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100787
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and its compliance meta data.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: List of exactly three tensors: input, weights, bias.
        args_dict: Test arguments. "acc_type", "stride", "pad" (4 values)
            and "dilation" are read.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to ConvAttribute, or
            None to use [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700866
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator from [input, weights, bias] tensors.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: List of exactly three tensors: input, weights, bias.
        args_dict: Test arguments. "acc_type", "stride", "pad" (6 values)
            and "dilation" are read.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to ConvAttribute, or
            None to use [0, 0].

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
940
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    The ``filter`` parameter name is kept (although it shadows the builtin)
    so that existing keyword callers keep working.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        stride: Stride values.
        out_pad: Output padding, exactly 4 values.
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to
            TransposeConvAttribute, or None to use [0, 0].

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1003
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator from [input, weights, bias].

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: List of exactly three tensors: input, weights, bias.
        args_dict: Test arguments. "acc_type", "stride", "pad" and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to ConvAttribute, or
            None to use [0, 0].

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1076
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator.

    The ``filter`` parameter name is kept for caller compatibility.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype, also passed to the ERROR_IF checks.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to
            FullyConnectedAttribute, or None to use [0, 0].

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1125
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator and its compliance meta data.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: List of exactly two tensors, the A and B operands.
        args_dict: Test arguments; "acc_type" is read.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two-element zero-point list forwarded to MatMulAttribute, or
            None to use [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        ERROR_IF validation fails.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None. Fall back to zero points (as build_pool2d does)
    # instead of failing with a TypeError on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001175
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Build a REDUCE_* operator over tensor ``a`` along ``axis``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    reduced = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    n_pl, n_const = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reduced.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=reduced.shape,
        input_dtype=a.dtype,
        output_dtype=reduced.dtype,
        result_tensors=[reduced],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_pl + n_const,
    )
    if not checks_ok:
        return None

    reduce_attr = ts.TosaSerializerAttribute()
    reduce_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, reduce_attr)
    return reduced
1210
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator on ``a`` with random min/max bounds.

    For the MaxSmallerMin ERROR_IF case the two random values are forced to
    differ and are then assigned the wrong way round (max < min).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    placeholder_cnt, const_cnt = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    is_float_type = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if a.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    if is_float_type:
        # Float bounds go in the float slots; the integer slots stay 0.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1266
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on ``a`` with a random FP32 alpha.

    No ERROR_IF validation is performed here; ``validator_fcns`` is unused.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out.name], relu_attr)
    return out
1275
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator using only the single input tensor ``a``.

    No ERROR_IF validation is performed; ``validator_fcns`` is unused.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1282
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator on tensor ``a``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operand name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return out
1313
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator on tensor ``a``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    tanh_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    pl_count, const_count = op["operands"]
    operand_total = pl_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tanh_out.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tanh_out.shape,
        input_dtype=a.dtype,
        output_dtype=tanh_out.dtype,
        result_tensors=[tanh_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    ):
        self.ser.addOperator(op["op"], in_names, out_names)
        return tanh_out

    # Validation rejected this ERROR_IF test.
    return None
1344
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator on tensor ``a``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    erf_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    num_pl, num_const = op["operands"]
    inputs_named, outputs_named = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=erf_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=erf_tensor.dtype,
        result_tensors=[erf_tensor],
        input_list=inputs_named,
        output_list=outputs_named,
        num_operands=num_pl + num_const,
    )
    if ok:
        self.ser.addOperator(op["op"], inputs_named, outputs_named)
        return erf_tensor
    return None
1375
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator over a variable number of input tensors.

    The last positional argument is the concatenation axis. All preceding
    positional arguments are the input tensors. The integer check on the
    axis is skipped for the WrongInputType ERROR_IF case.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(a[-1], int)

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001424
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator from a single input tensor.

    ``args_dict`` supplies "pad" (an array that is flattened into the
    PadAttribute), "pad_const_int" and "pad_const_fp".

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    pad_values = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, tensor_in, pad_values, error_name
    )

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, pad_values.flatten(), const_int, const_fp
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=out_tensor.dtype,
        pad=pad_values,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=tensor_in,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, tensor_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001482
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator on tensor ``a`` with an AxisAttribute of ``axis``.

    ``qinfo`` is accepted but not used here.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1525
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE operator turning tensor ``a`` into shape ``newShape``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return out_tensor
1561
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REVERSE operator on tensor ``a`` along ``axis``.

    The output tensor is shaped with OutputShaper.unaryOp and the axis is
    serialized as an AxisAttribute.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1596
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator permuting the dimensions of ``a`` by ``perms``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1632
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator taking ``size`` elements of ``a`` from ``start``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not checks_ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1671
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator replicating ``a`` by ``multiples``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not checks_ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tensor
1706
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator reading from ``values``.

    A constant INT32 indices tensor of shape [values.shape[0], W] is created
    here. W comes from self.randInt over args.tensor_shape_range, and every
    index is drawn from [0, values.shape[1]) so no index exceeds the K
    dimension of ``values``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    k_dim = values.shape[1]
    low, high = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    w_dim = self.randInt(low, high)
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1753
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 indices tensor of shape [N, W] is created here, with
    N = values_in.shape[0], K = values_in.shape[1] and W = input.shape[1].
    Every index lies in [0, K).

    TOSA SCATTER does not allow the same output index to be written more than
    once. When W <= K, each batch row of indices is therefore a distinct random
    selection from range(K). When W > K, uniqueness is impossible, so indices
    are drawn independently as before (duplicates possible).

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    N = values_in.shape[0]
    K = values_in.shape[1]
    W = input.shape[1]
    if W <= K:
        # Shuffle 0..K-1 per batch and keep the first W: no repeated index.
        index_rows = [self.rng.permutation(K)[:W] for _ in range(N)]
        indicies_arr = np.array(index_rows, dtype=np.int32).reshape(N, W)
    else:
        indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indicies, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indicies.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1798
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    ``scale``, ``offset``, ``border`` and ``mode`` are serialized into a
    ResizeAttribute. The requested ``input_dtype`` and ``output_dtype`` are
    forwarded to OutputShaper.resizeOp and to the ERROR_IF validators.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tensor
1860
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    Both outputs are shaped with OutputShaper.unaryOp. Only the first result
    tensor is returned, and ``validator_fcns`` is not used (no ERROR_IF
    validation is performed here).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder receives the op-definition dict and serializes
    # op["op"]; passing the dict itself to addOperator is wrong. A bare op
    # code is still accepted for backward compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1868
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as a graph output tensor and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are not used here.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001872
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting tensor ``val`` to ``out_dtype``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1906
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Random input/output zero points and random scales are chosen. Each scale
    is converted to a multiplier/shift pair with
    TosaQuantGen.computeMultiplierAndShift, and the result is serialized in a
    RescaleAttribute.

    Args:
        op: Operator definition dict (reads "op" and "operands").
        val: Input tensor. For valid (error_name is None) scale32 tests, its
            placeholder data file is clamped and rewritten if needed, so that
            ``value - input_zp`` lies within the apply_scale_32 bounds.
        out_dtype: Output data type.
        scale32: Use the 32-bit scale range (otherwise the 16-bit one).
        double_round: Serialized in the attribute and passed to validators.
        per_channel: One scale per element of the last input dimension
            instead of a single scale.
        validator_fcns: Passed to TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Each branch that can pick a non-zero zero point also widens the type
    # width by one bit (presumably headroom for the zero-point offset).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints. shift_arr[i] is an np.int32
        # scalar, and under NumPy's NEP 50 promotion rules
        # `-1 << (shift_arr[i] - 1)` stays int32 and overflows for shifts
        # above 31. scale32 shifts can reach that range because the scale
        # may be as small as 2^-31.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -(1 << (shift - 1))
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
2061
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for COND_IF tests.

    Normally this is a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool ERROR_IF case gets a wrong dtype from
    gtu.get_wrong_output_type. The CondIfCondShapeNotSizeOne case gets a
    randomly chosen shape of [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2078
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. The
    block bodies are freshly generated random INT32 constants of that shape.

    For the output-list mismatch ERROR_IF cases, the affected block outputs a
    constant with a perturbed shape instead, so it no longer matches the
    operator output.

    Returns the result tensor, or None if ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    output_mismatch_errors = [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if error_name in output_mismatch_errors:
        # Perturb every dimension; dimensions of 3 or less only grow so they
        # stay positive.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator output shares the (valid) block output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant which is also its output
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, block_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_const = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_const = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_const)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2147
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose blocks apply a binary op to ``a`` and ``b``.

    The blocks use:

    * FP32/BF16/FP16/INT32 inputs: ADD in THEN, SUB in ELSE.
    * INT8/INT16 inputs: LOGICAL_RIGHT_SHIFT in THEN, LOGICAL_LEFT_SHIFT
      in ELSE.

    For the input/output-list mismatch ERROR_IF cases, the selected block
    declares its first input (or its output) with a perturbed copy of
    ``a``'s shape.

    Returns the result tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, dimensions are not
        # guarded here, so small dimensions can become zero or negative.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be called ``op``: that name is the operator
    # dict handed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2230
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that adds ``a`` into an accumulator each pass.

    The loop carries (counter, a, accumulator):

    * The counter starts at ``iter_val``.
    * The accumulator starts as zeros of ``a.shape``.
    * The COND block computes ``counter > 0``.
    * The BODY block computes ``acc + a`` and ``counter - 1``.

    For ERROR_IF cases, shapes are perturbed in the operator outputs or in
    one of the blocks, or the COND output dtype/shape is made invalid.

    Returns the accumulator output tensor, or None if ERROR_IF validation
    fails.
    """
    # Loop counter; deliberately not named ``iter`` to avoid shadowing the
    # builtin.
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Operator accumulator output deliberately differs in shape from its input
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # Rank-0 counter: give it one extra dimension instead.
            # NOTE(review): the chosen size can be negative.
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2351
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D test from the two input tensors ``val1`` and ``val2``.

    Output tensors come from ``OutputShaper.fft2dOp``. The input/output name
    lists are passed through ``TosaErrorIfArgGen.eiInvalidateInputOutputList``
    before validation (presumably to corrupt them for the list-related
    ERROR_IF cases).

    Returns the list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [res.shape for res in results]
    dtypes_out = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2393
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D test on the single input tensor ``val``.

    Output tensors come from ``OutputShaper.rfft2dOp``. The input/output name
    lists are passed through ``TosaErrorIfArgGen.eiInvalidateInputOutputList``
    before validation (presumably to corrupt them for the list-related
    ERROR_IF cases).

    Returns the list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    shapes_out = [res.shape for res in results]
    dtypes_out = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=shapes_out,
        output_dtype=dtypes_out,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    ):
        self.ser.addOperator(op["op"], input_names, output_names)
        return results
    return None
2427
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters to generate tests for ``op``.

    Args:
        op: Operator description dict; ``op["rank"]`` is a ``(min, max)``
            pair and ``op["types"]`` the supported dtypes.
        shapeFilter: Requested shapes, or a falsy value for none.
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator whose ``param_reqs`` override the
            filters for negative tests.

    Returns:
        A dict with ``shapeFilter``, ``rankFilter`` and ``dtypeFilter`` keys.
        Returns None for a negative test without a validator, or for an
        unknown ``testType``.
    """
    # Keep default test sizes reasonably small: ranks 1-4 inclusive
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Only keep requested ranks that the operator supports
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Bounded by the default range or the operator range, whichever
        # is smaller
        op_ranks = range(rmin, rmax + 1)
        if len(op_ranks) <= len(default_test_rank_range):
            cleanRankFilter = op_ranks
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        # Operator dtypes (or dtype lists led by a requested dtype) that
        # were requested
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    param_reqs = validator(check=False, op=op)["param_reqs"]

    def _required_or(key, fallback):
        # Use the validator's requirement when it sets one
        required = param_reqs[key]
        return fallback if required is None else required

    return {
        # Reduce number of shapes to keep test numbers small
        "shapeFilter": _required_or("shape", shapeFilter[:2])
        if param_reqs["rank"] is param_reqs["rank"]
        else None,
        "rankFilter": _required_or("rank", cleanRankFilter),
        "dtypeFilter": _required_or("dtype", cleanDtypeFilter),
    }
2508
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptors for operator ``opName``.

    Args:
        opName: Key into ``self.TOSA_OP_LIST``.
        shapeFilter: Optional list of target shapes. ``None`` (the default)
            or an empty list means "no shape filter"; ``create_filter_lists``
            normalizes both to ``[None]``.
        rankFilter: Optional list of ranks to restrict generation to.
        dtypeFilter: Optional list of dtypes to restrict generation to.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        A list of tuples ``(opName, testStr, dtype, error_name, shapeList,
        args)``.

    Raises:
        Exception: if ``opName`` is not a known operator.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator (re-seeded on every call)
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator; skip it without
            # discarding tests already built for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        if testType == "positive":
            namePrefix = opName
        else:
            namePrefix = f"{opName}_ERRORIF_{error_name}"

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, build-args) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{namePrefix}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr = f"{testStr}_{argStr}"
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]

        def _is_invalid(test):
            # True when any invalid-test validator rejects this test tuple
            test_op, _, test_dtype, _, test_shapes, test_args = test
            return any(
                validator_fcn(
                    opName=test_op,
                    input_dtype=test_dtype,
                    shapeList=test_shapes,
                    args=test_args,
                )
                for validator_fcn in invalid_test_validators
            )

        testList = [test for test in testList if not _is_invalid(test)]

    return testList
2615
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build and serialize one test described by ``genOpTestList`` output.

    Steps:

    1. Create a serializer for the test.
    2. Generate the input tensors with the operator's tensor-values
       generator.
    3. Run the operator's build function.
    4. If the build returned a truthy result, serialize the test along with
       any data-generation/compliance meta data. Otherwise report an invalid
       ERROR_IF test.

    ``testArgs`` is either a dict (new interface: handed whole to the
    generator and build function) or a sequence (old interface: unpacked as
    positional build arguments).

    Raises:
        Exception: if ``opName`` is not a known operator.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    is_concat = op["op"] == Op.CONCAT
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT has a variable input count: one dtype per input shape
        dtype_count = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * dtype_count

    if not is_concat:
        assert len(shapeList) == num_operands, (
            f"shapeList length {len(shapeList)} "
            f"must match number of operands {num_operands}"
        )
        assert len(dtypeList) == num_operands, (
            f"dtypeList length {len(dtypeList)} "
            f"must match number of operands {num_operands}"
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: argument info held in an args dictionary
        assert "dg_type" in testArgs
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, testArgs, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
        result = build_fcn(
            self,
            op,
            tens,
            testArgs,
            validator_fcns=validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: argument info is a list of positional build args
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        try:
            call_args = [*tens, *testArgs]
            call_kwargs = {}
            if validators is not None:
                call_kwargs["validator_fcns"] = validators
                call_kwargs["error_name"] = error_name
                if qinfo is not None:
                    call_kwargs["qinfo"] = qinfo
            elif qinfo is not None:
                # Without validators, qinfo is passed positionally
                call_args.append(qinfo)
            result = build_fcn(self, op, *call_args, **call_kwargs)
        except TypeError:
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002741
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE entries of TOSA_OP_LIST into per-kernel ops.

    Each 2D template (conv2d, depthwise_conv2d, transpose_conv2d) is copied
    once per 2D kernel size, and the conv3d template once per 3D kernel size.
    Each copy gets "filter" set to the kernel and "template" set to False.
    Afterwards every entry whose "template" flag is truthy is deleted.
    TOSA_OP_LIST is modified in place. If "conv2d_TEMPLATE" is already
    absent the method returns without doing anything.
    """
    op_list = self.TOSA_OP_LIST

    # The templates are removed at the end of a previous expansion, so their
    # absence means this already ran (the class can be initialized repeatedly).
    if "conv2d_TEMPLATE" not in op_list:
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(prefix, kernel):
        # Shallow copy of the template; only "filter" and "template" are
        # (re)assigned on the new entry, other values stay shared.
        variant = op_list[prefix + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[prefix + "_" + "x".join(str(dim) for dim in kernel)] = variant

    # Kernel loop outermost so new ops are inserted kernel by kernel.
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(prefix, kernel)
    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect the names first: deleting keys while iterating a dict is an error.
    for name in [key for key, entry in op_list.items() if entry.get("template")]:
        del op_list[name]
2797
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each op must provide:
    - an "operands" pair;
    - a four-item "build_fcn" sequence;
    - a "types" entry;
    - an "op" entry.

    The first problem found raises a plain Exception naming the op. Entries
    without a "rank" key are given DEFAULT_RANK_RANGE in place.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        # Unpacking checks presence (KeyError), iterability (TypeError) and
        # length (ValueError) in one step; the unpacked values are unused.
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tensor_gen, _values_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        # Presence-only checks for the remaining required keys, in order.
        for key, description in (
            ("types", "a valid type list"),
            ("op", "the Op field"),
        ):
            try:
                entry[key]  # only the lookup matters: raises KeyError if absent
            except KeyError:
                raise Exception(
                    f"Op {op_name} is missing {description} in TOSA_OP_LIST"
                )

        # Optional rank range: fall back to the class-wide default.
        try:
            entry["rank"]
        except KeyError:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002839
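# Tensor operator list
# Required keys:
# 'op': the Op enum value of the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'build_fcn': 4-tuple of (build function, TosaTensorGen shape generator,
#    TosaTensorValuesGen tensor values generator, TosaArgGen argument
#    generator or None)
# 'types': list of datatypes to be tested; for convolution-style ops each
#    item is an [Input Type 1, Input Type 2, Accumulator Type] list
# Optional keys:
# 'rank': restricts rank to the tuple (min, max), inclusive; if not
#    specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
# 'qgen': TosaQuantGen function for the op (e.g. TosaQuantGen.qgConv)
# 'invalid_test_validators': tuple of TosaInvalidValidator functions
# 'error_if_validators': tuple of TosaErrorValidator functions used when
#    building ERROR_IF tests
# 'data_gen': dict mapping a type group (e.g. "fp") to a tuple of
#    gtu.DataGenType values
# 'compliance': compliance settings for the op (e.g. {"ulp": 0.5})
# 'template': True marks an entry that createDynamicOpLists expands per
#    kernel size and then removes
# 'filter': kernel size, set on each op generated by createDynamicOpLists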
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (INT4 deliberately left out)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer types followed by floating-point types (still no INT4)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types followed by INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# Each item: [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used for ops whose TOSA_OP_LIST entry has no "rank" key
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002891
2892 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002893 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002894 "argmax": {
2895 "op": Op.ARGMAX,
2896 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002897 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002898 "build_fcn": (
2899 build_argmax,
2900 TosaTensorGen.tgBasic,
2901 TosaTensorValuesGen.tvgDefault,
2902 TosaArgGen.agAxis,
2903 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002904 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002905 "error_if_validators": (
2906 TosaErrorValidator.evAxisSmallerZero,
2907 TosaErrorValidator.evAxisLargerRank,
2908 TosaErrorValidator.evArgmaxOutputRankMismatch,
2909 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2910 TosaErrorValidator.evWrongRank,
2911 TosaErrorValidator.evWrongInputType,
2912 TosaErrorValidator.evWrongOutputType,
2913 TosaErrorValidator.evWrongInputList,
2914 TosaErrorValidator.evWrongOutputList,
2915 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002916 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 "avg_pool2d": {
2918 "op": Op.AVG_POOL2D,
2919 "operands": (1, 0),
2920 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002921 "build_fcn": (
2922 build_pool2d,
2923 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002924 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002925 TosaArgGen.agPooling,
2926 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002927 "qgen": TosaQuantGen.qgUnary,
2928 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002929 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002930 "error_if_validators": (
2931 TosaErrorValidator.evKernelSmallerOne,
2932 TosaErrorValidator.evStrideSmallerOne,
2933 TosaErrorValidator.evPadSmallerZero,
2934 TosaErrorValidator.evWrongRank,
2935 TosaErrorValidator.evWrongInputType,
2936 TosaErrorValidator.evWrongOutputType,
2937 TosaErrorValidator.evWrongInputList,
2938 TosaErrorValidator.evWrongOutputList,
2939 TosaErrorValidator.evInputZeroPointNotZero,
2940 TosaErrorValidator.evOutputZeroPointNotZero,
2941 TosaErrorValidator.evPadLargerEqualKernel,
2942 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002943 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002945 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002946 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002947 "conv2d_TEMPLATE": {
2948 "op": Op.CONV2D,
2949 "operands": (1, 2),
2950 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002951 "build_fcn": (
2952 build_conv2d,
2953 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002954 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002955 TosaArgGen.agConv,
2956 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002957 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002958 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002959 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2960 "error_if_validators": (
2961 TosaErrorValidator.evWrongInputType,
2962 TosaErrorValidator.evWrongOutputType,
2963 TosaErrorValidator.evWrongInputList,
2964 TosaErrorValidator.evWrongOutputList,
2965 TosaErrorValidator.evInputZeroPointNotZero,
2966 TosaErrorValidator.evWeightZeroPointNotZero,
2967 TosaErrorValidator.evPadSmallerZero,
2968 TosaErrorValidator.evStrideSmallerOne,
2969 TosaErrorValidator.evDilationSmallerOne,
2970 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002971 TosaErrorValidator.evConvOutputShapeMismatch,
2972 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002973 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002974 "data_gen": {
2975 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2976 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002977 "template": True,
2978 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002979 # Templated operator. Filled in by createDynamicOpLists
2980 "conv3d_TEMPLATE": {
2981 "op": Op.CONV3D,
2982 "operands": (1, 2),
2983 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002984 "build_fcn": (
2985 build_conv3d,
2986 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002987 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002988 TosaArgGen.agConv,
2989 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002990 "qgen": TosaQuantGen.qgConv,
2991 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002992 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2993 "error_if_validators": (
2994 TosaErrorValidator.evWrongInputType,
2995 TosaErrorValidator.evWrongOutputType,
2996 TosaErrorValidator.evWrongInputList,
2997 TosaErrorValidator.evWrongOutputList,
2998 TosaErrorValidator.evInputZeroPointNotZero,
2999 TosaErrorValidator.evWeightZeroPointNotZero,
3000 TosaErrorValidator.evPadSmallerZero,
3001 TosaErrorValidator.evStrideSmallerOne,
3002 TosaErrorValidator.evDilationSmallerOne,
3003 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003004 TosaErrorValidator.evConvOutputShapeMismatch,
3005 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003006 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003007 "template": True,
3008 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003009 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003010 "depthwise_conv2d_TEMPLATE": {
3011 "op": Op.DEPTHWISE_CONV2D,
3012 "operands": (1, 2),
3013 "filter": [1, 1],
3014 "rank": (4, 4),
3015 "build_fcn": (
3016 build_depthwise_conv2d,
3017 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003018 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003019 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003020 ),
3021 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003022 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003023 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3024 "error_if_validators": (
3025 TosaErrorValidator.evWrongInputType,
3026 TosaErrorValidator.evWrongOutputType,
3027 TosaErrorValidator.evWrongInputList,
3028 TosaErrorValidator.evWrongOutputList,
3029 TosaErrorValidator.evInputZeroPointNotZero,
3030 TosaErrorValidator.evWeightZeroPointNotZero,
3031 TosaErrorValidator.evPadSmallerZero,
3032 TosaErrorValidator.evStrideSmallerOne,
3033 TosaErrorValidator.evDilationSmallerOne,
3034 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003035 TosaErrorValidator.evConvOutputShapeMismatch,
3036 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003037 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003038 "template": True,
3039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003040 "fully_connected": {
3041 "op": Op.FULLY_CONNECTED,
3042 "operands": (1, 2),
3043 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 "build_fcn": (
3045 build_fully_connected,
3046 TosaTensorGen.tgFullyConnected,
3047 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01003048 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003049 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003051 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003052 "error_if_validators": (
3053 TosaErrorValidator.evInputZeroPointNotZero,
3054 TosaErrorValidator.evWeightZeroPointNotZero,
3055 TosaErrorValidator.evWrongRank,
3056 TosaErrorValidator.evWrongInputType,
3057 TosaErrorValidator.evWrongOutputType,
3058 TosaErrorValidator.evWrongInputList,
3059 TosaErrorValidator.evWrongOutputList,
3060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003062 "matmul": {
3063 "op": Op.MATMUL,
3064 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003065 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 "build_fcn": (
3067 build_matmul,
3068 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003069 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003070 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 "qgen": TosaQuantGen.qgMatmul,
3073 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003074 "error_if_validators": (
3075 TosaErrorValidator.evInputZeroPointNotZero,
3076 TosaErrorValidator.evWrongRank,
3077 TosaErrorValidator.evWrongInputType,
3078 TosaErrorValidator.evWrongOutputType,
3079 TosaErrorValidator.evWrongInputList,
3080 TosaErrorValidator.evWrongOutputList,
3081 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003082 "data_gen": {
3083 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003084 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 "max_pool2d": {
3087 "op": Op.MAX_POOL2D,
3088 "operands": (1, 0),
3089 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003091 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003092 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003093 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003094 TosaArgGen.agPooling,
3095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003097 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evKernelSmallerOne,
3100 TosaErrorValidator.evStrideSmallerOne,
3101 TosaErrorValidator.evPadSmallerZero,
3102 TosaErrorValidator.evWrongRank,
3103 TosaErrorValidator.evWrongInputType,
3104 TosaErrorValidator.evWrongOutputType,
3105 TosaErrorValidator.evWrongInputList,
3106 TosaErrorValidator.evWrongOutputList,
3107 TosaErrorValidator.evPadLargerEqualKernel,
3108 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003109 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003110 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003111 "data_gen": {
3112 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3113 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003115 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003116 "transpose_conv2d_TEMPLATE": {
3117 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003118 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003119 "rank": (4, 4),
3120 "build_fcn": (
3121 build_transpose_conv2d,
3122 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003123 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003124 TosaArgGen.agTransposeConv2D,
3125 ),
3126 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003127 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003128 "invalid_test_validators": (
3129 TosaInvalidValidator.ivHeightWidthInvalid,
3130 TosaInvalidValidator.ivNonPositiveOutputShape,
3131 ),
3132 "error_if_validators": (
3133 TosaErrorValidator.evWrongInputType,
3134 TosaErrorValidator.evWrongOutputType,
3135 TosaErrorValidator.evWrongInputList,
3136 TosaErrorValidator.evWrongOutputList,
3137 TosaErrorValidator.evInputZeroPointNotZero,
3138 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003139 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003140 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003141 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003142 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003143 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003144 "template": True,
3145 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003146 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003147 "clamp": {
3148 "op": Op.CLAMP,
3149 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003150 "build_fcn": (
3151 build_clamp,
3152 TosaTensorGen.tgBasic,
3153 TosaTensorValuesGen.tvgDefault,
3154 None,
3155 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003156 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003157 "error_if_validators": (
3158 TosaErrorValidator.evMaxSmallerMin,
3159 TosaErrorValidator.evWrongInputType,
3160 TosaErrorValidator.evWrongOutputType,
3161 TosaErrorValidator.evWrongInputList,
3162 TosaErrorValidator.evWrongOutputList,
3163 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003164 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003165 "sigmoid": {
3166 "op": Op.SIGMOID,
3167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003168 "build_fcn": (
3169 build_sigmoid,
3170 TosaTensorGen.tgBasic,
3171 TosaTensorValuesGen.tvgDefault,
3172 None,
3173 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003174 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003175 "error_if_validators": (
3176 TosaErrorValidator.evWrongInputType,
3177 TosaErrorValidator.evWrongOutputType,
3178 TosaErrorValidator.evWrongInputList,
3179 TosaErrorValidator.evWrongOutputList,
3180 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003181 },
3182 "tanh": {
3183 "op": Op.TANH,
3184 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003185 "build_fcn": (
3186 build_tanh,
3187 TosaTensorGen.tgBasic,
3188 TosaTensorValuesGen.tvgDefault,
3189 None,
3190 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003191 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003192 "error_if_validators": (
3193 TosaErrorValidator.evWrongInputType,
3194 TosaErrorValidator.evWrongOutputType,
3195 TosaErrorValidator.evWrongInputList,
3196 TosaErrorValidator.evWrongOutputList,
3197 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003198 },
Won Jeon78155c62023-06-10 00:20:04 +00003199 "erf": {
3200 "op": Op.ERF,
3201 "operands": (1, 0),
3202 "build_fcn": (
3203 build_erf,
3204 TosaTensorGen.tgBasic,
3205 TosaTensorValuesGen.tvgDefault,
3206 None,
3207 ),
3208 "types": TYPE_FP,
3209 "error_if_validators": (
3210 TosaErrorValidator.evWrongInputType,
3211 TosaErrorValidator.evWrongOutputType,
3212 TosaErrorValidator.evWrongInputList,
3213 TosaErrorValidator.evWrongOutputList,
3214 ),
3215 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003216 # Elementwise Binary Operators
3217 "add": {
3218 "op": Op.ADD,
3219 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
3221 build_binary_broadcast,
3222 TosaTensorGen.tgBroadcastFuzz,
3223 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003224 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evRankMismatch,
3229 TosaErrorValidator.evWrongInputType,
3230 TosaErrorValidator.evWrongOutputType,
3231 TosaErrorValidator.evWrongInputList,
3232 TosaErrorValidator.evWrongOutputList,
3233 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003234 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003236 "data_gen": {
3237 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3238 },
3239 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003241 "arithmetic_right_shift": {
3242 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3243 "operands": (2, 0),
3244 "build_fcn": (
3245 build_arithmetic_right_shift,
3246 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 TosaArgGen.agArithmeticRightShift,
3249 ),
3250 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003251 "error_if_validators": (
3252 TosaErrorValidator.evRankMismatch,
3253 TosaErrorValidator.evWrongInputType,
3254 TosaErrorValidator.evWrongOutputType,
3255 TosaErrorValidator.evWrongInputList,
3256 TosaErrorValidator.evWrongOutputList,
3257 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003258 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003259 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003260 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 "bitwise_and": {
3262 "op": Op.BITWISE_AND,
3263 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 "build_fcn": (
3265 build_binary_broadcast,
3266 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003267 TosaTensorValuesGen.tvgLazyGenDefault,
3268 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003271 "error_if_validators": (
3272 TosaErrorValidator.evRankMismatch,
3273 TosaErrorValidator.evWrongInputType,
3274 TosaErrorValidator.evWrongOutputType,
3275 TosaErrorValidator.evWrongInputList,
3276 TosaErrorValidator.evWrongOutputList,
3277 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003278 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003279 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 "bitwise_or": {
3282 "op": Op.BITWISE_OR,
3283 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 "build_fcn": (
3285 build_binary_broadcast,
3286 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003287 TosaTensorValuesGen.tvgLazyGenDefault,
3288 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003291 "error_if_validators": (
3292 TosaErrorValidator.evRankMismatch,
3293 TosaErrorValidator.evWrongInputType,
3294 TosaErrorValidator.evWrongOutputType,
3295 TosaErrorValidator.evWrongInputList,
3296 TosaErrorValidator.evWrongOutputList,
3297 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003298 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 "bitwise_xor": {
3302 "op": Op.BITWISE_XOR,
3303 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003304 "build_fcn": (
3305 build_binary_broadcast,
3306 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003307 TosaTensorValuesGen.tvgLazyGenDefault,
3308 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003309 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003311 "error_if_validators": (
3312 TosaErrorValidator.evRankMismatch,
3313 TosaErrorValidator.evWrongInputType,
3314 TosaErrorValidator.evWrongOutputType,
3315 TosaErrorValidator.evWrongInputList,
3316 TosaErrorValidator.evWrongOutputList,
3317 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003318 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003321 "intdiv": {
3322 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003323 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003324 "build_fcn": (
3325 build_binary_broadcast,
3326 TosaTensorGen.tgBroadcastFuzz,
3327 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003328 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003329 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003330 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003331 "error_if_validators": (
3332 TosaErrorValidator.evRankMismatch,
3333 TosaErrorValidator.evWrongInputType,
3334 TosaErrorValidator.evWrongOutputType,
3335 TosaErrorValidator.evWrongInputList,
3336 TosaErrorValidator.evWrongOutputList,
3337 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003338 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003339 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003340 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 "logical_and": {
3342 "op": Op.LOGICAL_AND,
3343 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003344 "build_fcn": (
3345 build_binary_broadcast,
3346 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003347 TosaTensorValuesGen.tvgLazyGenDefault,
3348 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003351 "error_if_validators": (
3352 TosaErrorValidator.evRankMismatch,
3353 TosaErrorValidator.evWrongInputType,
3354 TosaErrorValidator.evWrongOutputType,
3355 TosaErrorValidator.evWrongInputList,
3356 TosaErrorValidator.evWrongOutputList,
3357 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003358 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 "logical_left_shift": {
3362 "op": Op.LOGICAL_LEFT_SHIFT,
3363 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003364 "build_fcn": (
3365 build_binary_broadcast,
3366 TosaTensorGen.tgBroadcastFuzz,
3367 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003368 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003371 "error_if_validators": (
3372 TosaErrorValidator.evRankMismatch,
3373 TosaErrorValidator.evWrongInputType,
3374 TosaErrorValidator.evWrongOutputType,
3375 TosaErrorValidator.evWrongInputList,
3376 TosaErrorValidator.evWrongOutputList,
3377 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003378 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "logical_right_shift": {
3382 "op": Op.LOGICAL_RIGHT_SHIFT,
3383 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003384 "build_fcn": (
3385 build_binary_broadcast,
3386 TosaTensorGen.tgBroadcastFuzz,
3387 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003388 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003389 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003390 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003391 "error_if_validators": (
3392 TosaErrorValidator.evRankMismatch,
3393 TosaErrorValidator.evWrongInputType,
3394 TosaErrorValidator.evWrongOutputType,
3395 TosaErrorValidator.evWrongInputList,
3396 TosaErrorValidator.evWrongOutputList,
3397 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003398 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003400 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 "logical_or": {
3402 "op": Op.LOGICAL_OR,
3403 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003404 "build_fcn": (
3405 build_binary_broadcast,
3406 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003407 TosaTensorValuesGen.tvgLazyGenDefault,
3408 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003409 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003410 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003411 "error_if_validators": (
3412 TosaErrorValidator.evRankMismatch,
3413 TosaErrorValidator.evWrongInputType,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList,
3416 TosaErrorValidator.evWrongOutputList,
3417 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003418 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003419 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003420 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003421 "logical_xor": {
3422 "op": Op.LOGICAL_XOR,
3423 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003424 "build_fcn": (
3425 build_binary_broadcast,
3426 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003427 TosaTensorValuesGen.tvgLazyGenDefault,
3428 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003431 "error_if_validators": (
3432 TosaErrorValidator.evRankMismatch,
3433 TosaErrorValidator.evWrongInputType,
3434 TosaErrorValidator.evWrongOutputType,
3435 TosaErrorValidator.evWrongInputList,
3436 TosaErrorValidator.evWrongOutputList,
3437 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003438 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003439 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 "maximum": {
3442 "op": Op.MAXIMUM,
3443 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 "build_fcn": (
3445 build_binary_broadcast,
3446 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003447 TosaTensorValuesGen.tvgLazyGenDefault,
3448 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003450 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 "error_if_validators": (
3452 TosaErrorValidator.evRankMismatch,
3453 TosaErrorValidator.evWrongInputType,
3454 TosaErrorValidator.evWrongOutputType,
3455 TosaErrorValidator.evWrongInputList,
3456 TosaErrorValidator.evWrongOutputList,
3457 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003458 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003459 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003460 "data_gen": {
3461 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "minimum": {
3465 "op": Op.MINIMUM,
3466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_binary_broadcast,
3469 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003470 TosaTensorValuesGen.tvgLazyGenDefault,
3471 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evRankMismatch,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003481 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003483 "data_gen": {
3484 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3485 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 "mul": {
3488 "op": Op.MUL,
3489 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003490 "build_fcn": (
3491 build_mul,
3492 TosaTensorGen.tgBroadcastFuzz,
3493 TosaTensorValuesGen.tvgMul,
3494 TosaArgGen.agMul,
3495 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003497 "error_if_validators": (
3498 TosaErrorValidator.evWrongInputType,
3499 TosaErrorValidator.evWrongOutputType,
3500 TosaErrorValidator.evWrongInputList,
3501 TosaErrorValidator.evWrongOutputList,
3502 TosaErrorValidator.evRankMismatch,
3503 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003504 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003505 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003506 "data_gen": {
3507 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3508 },
3509 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 "pow": {
3512 "op": Op.POW,
3513 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 "build_fcn": (
3515 build_binary_broadcast,
3516 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003517 TosaTensorValuesGen.tvgLazyGenDefault,
3518 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003521 "error_if_validators": (
3522 TosaErrorValidator.evRankMismatch,
3523 TosaErrorValidator.evWrongInputType,
3524 TosaErrorValidator.evWrongOutputType,
3525 TosaErrorValidator.evWrongInputList,
3526 TosaErrorValidator.evWrongOutputList,
3527 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003528 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003530 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 "sub": {
3532 "op": Op.SUB,
3533 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 "build_fcn": (
3535 build_binary_broadcast,
3536 TosaTensorGen.tgBroadcastFuzz,
3537 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003538 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003540 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003541 "error_if_validators": (
3542 TosaErrorValidator.evRankMismatch,
3543 TosaErrorValidator.evWrongInputType,
3544 TosaErrorValidator.evWrongOutputType,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003548 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003549 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003550 "data_gen": {
3551 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3552 },
3553 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003554 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 "table": {
3556 "op": Op.TABLE,
3557 # Use the automatic generation functions to create the input array
3558 # but create the table tensor in the build function, as it may be
3559 # a different type from the input
3560 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 "build_fcn": (
3562 build_table,
3563 TosaTensorGen.tgBasic,
3564 TosaTensorValuesGen.tvgDefault,
3565 TosaArgGen.agTable,
3566 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003567 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003568 "error_if_validators": (
3569 TosaErrorValidator.evWrongInputType,
3570 TosaErrorValidator.evWrongOutputType,
3571 TosaErrorValidator.evWrongInputList,
3572 TosaErrorValidator.evWrongOutputList,
3573 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 # Elementwise Unary operators
3576 "abs": {
3577 "op": Op.ABS,
3578 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 "build_fcn": (
3580 build_unary,
3581 TosaTensorGen.tgBasic,
3582 TosaTensorValuesGen.tvgDefault,
3583 None,
3584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003585 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003586 "error_if_validators": (
3587 TosaErrorValidator.evWrongInputType,
3588 TosaErrorValidator.evWrongOutputType,
3589 TosaErrorValidator.evWrongInputList,
3590 TosaErrorValidator.evWrongOutputList,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "bitwise_not": {
3594 "op": Op.BITWISE_NOT,
3595 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003596 "build_fcn": (
3597 build_unary,
3598 TosaTensorGen.tgBasic,
3599 TosaTensorValuesGen.tvgDefault,
3600 None,
3601 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003602 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003603 "error_if_validators": (
3604 TosaErrorValidator.evWrongInputType,
3605 TosaErrorValidator.evWrongOutputType,
3606 TosaErrorValidator.evWrongInputList,
3607 TosaErrorValidator.evWrongOutputList,
3608 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003609 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003610 "ceil": {
3611 "op": Op.CEIL,
3612 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 "build_fcn": (
3614 build_unary,
3615 TosaTensorGen.tgBasic,
3616 TosaTensorValuesGen.tvgDefault,
3617 None,
3618 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003620 "error_if_validators": (
3621 TosaErrorValidator.evWrongInputType,
3622 TosaErrorValidator.evWrongOutputType,
3623 TosaErrorValidator.evWrongInputList,
3624 TosaErrorValidator.evWrongOutputList,
3625 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 "clz": {
3628 "op": Op.CLZ,
3629 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 "build_fcn": (
3631 build_unary,
3632 TosaTensorGen.tgBasic,
3633 TosaTensorValuesGen.tvgDefault,
3634 None,
3635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 "error_if_validators": (
3638 TosaErrorValidator.evWrongInputType,
3639 TosaErrorValidator.evWrongOutputType,
3640 TosaErrorValidator.evWrongInputList,
3641 TosaErrorValidator.evWrongOutputList,
3642 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003644 "exp": {
3645 "op": Op.EXP,
3646 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003647 "build_fcn": (
3648 build_unary,
3649 TosaTensorGen.tgBasic,
3650 TosaTensorValuesGen.tvgDefault,
3651 None,
3652 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 "error_if_validators": (
3655 TosaErrorValidator.evWrongInputType,
3656 TosaErrorValidator.evWrongOutputType,
3657 TosaErrorValidator.evWrongInputList,
3658 TosaErrorValidator.evWrongOutputList,
3659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "floor": {
3662 "op": Op.FLOOR,
3663 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003664 "build_fcn": (
3665 build_unary,
3666 TosaTensorGen.tgBasic,
3667 TosaTensorValuesGen.tvgDefault,
3668 None,
3669 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003671 "error_if_validators": (
3672 TosaErrorValidator.evWrongInputType,
3673 TosaErrorValidator.evWrongOutputType,
3674 TosaErrorValidator.evWrongInputList,
3675 TosaErrorValidator.evWrongOutputList,
3676 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "log": {
3679 "op": Op.LOG,
3680 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003681 "build_fcn": (
3682 build_unary,
3683 TosaTensorGen.tgBasic,
3684 TosaTensorValuesGen.tvgDefault,
3685 None,
3686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003688 "error_if_validators": (
3689 TosaErrorValidator.evWrongInputType,
3690 TosaErrorValidator.evWrongOutputType,
3691 TosaErrorValidator.evWrongInputList,
3692 TosaErrorValidator.evWrongOutputList,
3693 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 "logical_not": {
3696 "op": Op.LOGICAL_NOT,
3697 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003698 "build_fcn": (
3699 build_unary,
3700 TosaTensorGen.tgBasic,
3701 TosaTensorValuesGen.tvgDefault,
3702 None,
3703 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003704 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003705 "error_if_validators": (
3706 TosaErrorValidator.evWrongInputType,
3707 TosaErrorValidator.evWrongOutputType,
3708 TosaErrorValidator.evWrongInputList,
3709 TosaErrorValidator.evWrongOutputList,
3710 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003712 "negate": {
3713 "op": Op.NEGATE,
3714 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 "build_fcn": (
3716 build_unary,
3717 TosaTensorGen.tgBasic,
3718 TosaTensorValuesGen.tvgNegate,
3719 None,
3720 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "qgen": TosaQuantGen.qgUnary,
3722 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003723 "error_if_validators": (
3724 TosaErrorValidator.evInputZeroPointNotZero,
3725 TosaErrorValidator.evOutputZeroPointNotZero,
3726 TosaErrorValidator.evWrongInputType,
3727 TosaErrorValidator.evWrongOutputType,
3728 TosaErrorValidator.evWrongInputList,
3729 TosaErrorValidator.evWrongOutputList,
3730 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003732 "reciprocal": {
3733 "op": Op.RECIPROCAL,
3734 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_unary,
3737 TosaTensorGen.tgBasic,
3738 TosaTensorValuesGen.tvgDefault,
3739 None,
3740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
3747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "rsqrt": {
3750 "op": Op.RSQRT,
3751 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_unary,
3754 TosaTensorGen.tgBasic,
3755 TosaTensorValuesGen.tvgDefault,
3756 None,
3757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
3764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 # Elementwise Ternary operators
3767 "select": {
3768 "op": Op.SELECT,
3769 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003770 "build_fcn": (
3771 build_select,
3772 TosaTensorGen.tgBroadcastFuzz,
3773 TosaTensorValuesGen.tvgSelect,
3774 None,
3775 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003776 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003777 "error_if_validators": (
3778 TosaErrorValidator.evRankMismatch,
3779 TosaErrorValidator.evWrongInputType,
3780 TosaErrorValidator.evWrongOutputType,
3781 TosaErrorValidator.evWrongInputList,
3782 TosaErrorValidator.evWrongOutputList,
3783 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003784 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003787 # Comparison operators
3788 "equal": {
3789 "op": Op.EQUAL,
3790 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003791 "build_fcn": (
3792 build_comparison,
3793 TosaTensorGen.tgBroadcastFuzz,
3794 TosaTensorValuesGen.tvgEqual,
3795 None,
3796 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 "error_if_validators": (
3799 TosaErrorValidator.evRankMismatch,
3800 TosaErrorValidator.evWrongInputType,
3801 TosaErrorValidator.evWrongOutputType,
3802 TosaErrorValidator.evWrongInputList,
3803 TosaErrorValidator.evWrongOutputList,
3804 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003805 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003806 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 "greater_equal": {
3809 "op": Op.GREATER_EQUAL,
3810 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003811 "build_fcn": (
3812 build_comparison,
3813 TosaTensorGen.tgBroadcastFuzz,
3814 TosaTensorValuesGen.tvgDefault,
3815 None,
3816 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003817 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003818 "error_if_validators": (
3819 TosaErrorValidator.evRankMismatch,
3820 TosaErrorValidator.evWrongInputType,
3821 TosaErrorValidator.evWrongOutputType,
3822 TosaErrorValidator.evWrongInputList,
3823 TosaErrorValidator.evWrongOutputList,
3824 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003825 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003826 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "greater": {
3829 "op": Op.GREATER,
3830 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003831 "build_fcn": (
3832 build_comparison,
3833 TosaTensorGen.tgBroadcastFuzz,
3834 TosaTensorValuesGen.tvgDefault,
3835 None,
3836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003838 "error_if_validators": (
3839 TosaErrorValidator.evRankMismatch,
3840 TosaErrorValidator.evWrongInputType,
3841 TosaErrorValidator.evWrongOutputType,
3842 TosaErrorValidator.evWrongInputList,
3843 TosaErrorValidator.evWrongOutputList,
3844 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003845 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003846 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003847 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003848 # Reduction operators
3849 "reduce_all": {
3850 "op": Op.REDUCE_ALL,
3851 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003852 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003853 "build_fcn": (
3854 build_reduce,
3855 TosaTensorGen.tgBasic,
3856 TosaTensorValuesGen.tvgDefault,
3857 TosaArgGen.agAxis,
3858 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003859 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 "error_if_validators": (
3861 TosaErrorValidator.evAxisLargerRank,
3862 TosaErrorValidator.evAxisSmallerZero,
3863 TosaErrorValidator.evShapeOfAxisNotOne,
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
3866 TosaErrorValidator.evWrongRank,
3867 TosaErrorValidator.evWrongInputList,
3868 TosaErrorValidator.evWrongOutputList,
3869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003871 "reduce_any": {
3872 "op": Op.REDUCE_ANY,
3873 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003874 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003875 "build_fcn": (
3876 build_reduce,
3877 TosaTensorGen.tgBasic,
3878 TosaTensorValuesGen.tvgDefault,
3879 TosaArgGen.agAxis,
3880 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 "error_if_validators": (
3883 TosaErrorValidator.evAxisLargerRank,
3884 TosaErrorValidator.evAxisSmallerZero,
3885 TosaErrorValidator.evShapeOfAxisNotOne,
3886 TosaErrorValidator.evWrongInputType,
3887 TosaErrorValidator.evWrongOutputType,
3888 TosaErrorValidator.evWrongRank,
3889 TosaErrorValidator.evWrongInputList,
3890 TosaErrorValidator.evWrongOutputList,
3891 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003892 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 "reduce_max": {
3894 "op": Op.REDUCE_MAX,
3895 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003896 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 "build_fcn": (
3898 build_reduce,
3899 TosaTensorGen.tgBasic,
3900 TosaTensorValuesGen.tvgDefault,
3901 TosaArgGen.agAxis,
3902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 "error_if_validators": (
3905 TosaErrorValidator.evAxisLargerRank,
3906 TosaErrorValidator.evAxisSmallerZero,
3907 TosaErrorValidator.evShapeOfAxisNotOne,
3908 TosaErrorValidator.evWrongInputType,
3909 TosaErrorValidator.evWrongOutputType,
3910 TosaErrorValidator.evWrongRank,
3911 TosaErrorValidator.evWrongInputList,
3912 TosaErrorValidator.evWrongOutputList,
3913 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003916 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003917 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003918 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003919 "build_fcn": (
3920 build_reduce,
3921 TosaTensorGen.tgBasic,
3922 TosaTensorValuesGen.tvgDefault,
3923 TosaArgGen.agAxis,
3924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003926 "error_if_validators": (
3927 TosaErrorValidator.evAxisLargerRank,
3928 TosaErrorValidator.evAxisSmallerZero,
3929 TosaErrorValidator.evShapeOfAxisNotOne,
3930 TosaErrorValidator.evWrongInputType,
3931 TosaErrorValidator.evWrongOutputType,
3932 TosaErrorValidator.evWrongRank,
3933 TosaErrorValidator.evWrongInputList,
3934 TosaErrorValidator.evWrongOutputList,
3935 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 "reduce_product": {
3938 "op": Op.REDUCE_PRODUCT,
3939 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003940 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003941 "build_fcn": (
3942 build_reduce,
3943 TosaTensorGen.tgBasic,
3944 TosaTensorValuesGen.tvgDefault,
3945 TosaArgGen.agAxis,
3946 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003948 "error_if_validators": (
3949 TosaErrorValidator.evAxisLargerRank,
3950 TosaErrorValidator.evAxisSmallerZero,
3951 TosaErrorValidator.evShapeOfAxisNotOne,
3952 TosaErrorValidator.evWrongInputType,
3953 TosaErrorValidator.evWrongOutputType,
3954 TosaErrorValidator.evWrongRank,
3955 TosaErrorValidator.evWrongInputList,
3956 TosaErrorValidator.evWrongOutputList,
3957 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003958 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003959 "reduce_sum": {
3960 "op": Op.REDUCE_SUM,
3961 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003962 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003963 "build_fcn": (
3964 build_reduce,
3965 TosaTensorGen.tgBasic,
3966 TosaTensorValuesGen.tvgReduceSum,
3967 TosaArgGen.agAxis,
3968 ),
James Ward24dbc422022-10-19 12:20:31 +01003969 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003970 "error_if_validators": (
3971 TosaErrorValidator.evAxisLargerRank,
3972 TosaErrorValidator.evAxisSmallerZero,
3973 TosaErrorValidator.evShapeOfAxisNotOne,
3974 TosaErrorValidator.evWrongInputType,
3975 TosaErrorValidator.evWrongOutputType,
3976 TosaErrorValidator.evWrongRank,
3977 TosaErrorValidator.evWrongInputList,
3978 TosaErrorValidator.evWrongOutputList,
3979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003980 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003981 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003982 "concat": {
3983 "op": Op.CONCAT,
3984 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 "build_fcn": (
3986 build_concat,
3987 TosaTensorGen.tgConcat,
3988 TosaTensorValuesGen.tvgConcat,
3989 TosaArgGen.agAxis,
3990 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003991 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003992 "error_if_validators": (
3993 TosaErrorValidator.evAxisLargerRank,
3994 TosaErrorValidator.evAxisSmallerZero,
3995 TosaErrorValidator.evConcatInputRankMismatch,
3996 TosaErrorValidator.evConcatShapeSumMismatch,
3997 TosaErrorValidator.evConcatInputDimMismatch,
3998 TosaErrorValidator.evWrongInputType,
3999 TosaErrorValidator.evWrongOutputType,
4000 TosaErrorValidator.evWrongOutputList,
4001 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004002 },
4003 "pad": {
4004 "op": Op.PAD,
4005 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004006 "build_fcn": (
4007 build_pad,
4008 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004009 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004010 TosaArgGen.agPad,
4011 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004012 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004013 "error_if_validators": (
4014 TosaErrorValidator.evWrongInputType,
4015 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004016 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 TosaErrorValidator.evWrongOutputType,
4018 TosaErrorValidator.evWrongInputList,
4019 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004020 TosaErrorValidator.evRankMismatch,
4021 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004022 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004023 "data_gen": {
4024 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4025 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004026 },
Won Jeona21b2e82023-08-10 10:33:01 +00004027 "dim": {
4028 "op": Op.DIM,
4029 "operands": (1, 0),
4030 "build_fcn": (
4031 build_dim,
4032 TosaTensorGen.tgBasic,
4033 TosaTensorValuesGen.tvgDefault,
4034 TosaArgGen.agAxis,
4035 ),
4036 "types": TYPE_FIB,
4037 "error_if_validators": (
4038 TosaErrorValidator.evAxisLargerRank,
4039 TosaErrorValidator.evAxisSmallerZero,
4040 TosaErrorValidator.evWrongInputType,
4041 TosaErrorValidator.evWrongInputList,
4042 TosaErrorValidator.evWrongOutputList,
4043 TosaErrorValidator.evWrongRank,
4044 ),
4045 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004046 "reshape": {
4047 "op": Op.RESHAPE,
4048 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004049 "build_fcn": (
4050 build_reshape,
4051 TosaTensorGen.tgBasic,
4052 TosaTensorValuesGen.tvgDefault,
4053 TosaArgGen.agReshape,
4054 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004055 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004056 "error_if_validators": (
4057 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4058 TosaErrorValidator.evWrongInputType,
4059 TosaErrorValidator.evWrongOutputType,
4060 TosaErrorValidator.evWrongInputList,
4061 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004062 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4063 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004064 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004065 },
4066 "reverse": {
4067 "op": Op.REVERSE,
4068 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 "build_fcn": (
4070 build_reverse,
4071 TosaTensorGen.tgBasic,
4072 TosaTensorValuesGen.tvgDefault,
4073 TosaArgGen.agAxis,
4074 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004075 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004076 "error_if_validators": (
4077 TosaErrorValidator.evAxisSmallerZero,
4078 TosaErrorValidator.evAxisLargerRank,
4079 TosaErrorValidator.evWrongInputType,
4080 TosaErrorValidator.evWrongOutputType,
4081 TosaErrorValidator.evWrongInputList,
4082 TosaErrorValidator.evWrongOutputList,
4083 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004084 },
4085 "slice": {
4086 "op": Op.SLICE,
4087 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004088 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004089 "build_fcn": (
4090 build_slice,
4091 TosaTensorGen.tgBasic,
4092 TosaTensorValuesGen.tvgDefault,
4093 TosaArgGen.agSlice,
4094 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004096 "error_if_validators": (
4097 TosaErrorValidator.evStartSmallerZero,
4098 TosaErrorValidator.evSizeSmallerEqualZero,
4099 TosaErrorValidator.evStartSizeOutsideBounds,
4100 TosaErrorValidator.evSizeOutputShapeMismatch,
4101 TosaErrorValidator.evInputSizeStartLengthMismatch,
4102 TosaErrorValidator.evWrongRank,
4103 TosaErrorValidator.evWrongInputType,
4104 TosaErrorValidator.evWrongOutputType,
4105 TosaErrorValidator.evWrongInputList,
4106 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004107 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004108 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004109 },
4110 "tile": {
4111 "op": Op.TILE,
4112 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004113 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004114 "build_fcn": (
4115 build_tile,
4116 TosaTensorGen.tgBasic,
4117 TosaTensorValuesGen.tvgDefault,
4118 TosaArgGen.agTile,
4119 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004120 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004121 "error_if_validators": (
4122 TosaErrorValidator.evWrongInputType,
4123 TosaErrorValidator.evWrongOutputType,
4124 TosaErrorValidator.evWrongInputList,
4125 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004126 TosaErrorValidator.evRankMismatch,
4127 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004129 },
4130 "transpose": {
4131 "op": Op.TRANSPOSE,
4132 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004133 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004134 "build_fcn": (
4135 build_transpose,
4136 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004137 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004138 TosaArgGen.agTranspose,
4139 ),
4140 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004141 "error_if_validators": (
4142 TosaErrorValidator.evIndexOutsideBounds,
4143 TosaErrorValidator.evIndexUsedTwice,
4144 TosaErrorValidator.evWrongInputType,
4145 TosaErrorValidator.evWrongOutputType,
4146 TosaErrorValidator.evWrongInputList,
4147 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004148 TosaErrorValidator.evWrongRank,
4149 TosaErrorValidator.evRankMismatch,
4150 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004152 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004153 # Data nodes
4154 "const": {
4155 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004156 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004157 "build_fcn": (
4158 build_const,
4159 TosaTensorGen.tgBasic,
4160 TosaTensorValuesGen.tvgDefault,
4161 None,
4162 ),
Luke Hutton65872422023-02-20 10:33:04 +00004163 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "identity": {
4166 "op": Op.IDENTITY,
4167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004168 "build_fcn": (
4169 build_unary,
4170 TosaTensorGen.tgBasic,
4171 TosaTensorValuesGen.tvgDefault,
4172 None,
4173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 "types": TYPE_FIB,
4175 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004176 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004177 "gather": {
4178 "op": Op.GATHER,
4179 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4180 "operands": (1, 0),
4181 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004182 "build_fcn": (
4183 build_gather,
4184 TosaTensorGen.tgBasic,
4185 TosaTensorValuesGen.tvgDefault,
4186 None,
4187 ),
James Ward24dbc422022-10-19 12:20:31 +01004188 "types": (
4189 DType.INT8,
4190 DType.INT16,
4191 DType.INT32,
4192 DType.FP16,
4193 DType.BF16,
4194 DType.FP32,
4195 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004196 "error_if_validators": (
4197 TosaErrorValidator.evWrongInputType,
4198 TosaErrorValidator.evWrongOutputType,
4199 TosaErrorValidator.evWrongInputList,
4200 TosaErrorValidator.evWrongOutputList,
4201 TosaErrorValidator.evWrongRank,
4202 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004203 },
4204 "scatter": {
4205 "op": Op.SCATTER,
4206 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004207 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004208 "operands": (2, 0),
4209 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004210 "build_fcn": (
4211 build_scatter,
4212 TosaTensorGen.tgScatter,
4213 TosaTensorValuesGen.tvgDefault,
4214 None,
4215 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004216 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004217 "error_if_validators": (
4218 TosaErrorValidator.evWrongInputType,
4219 TosaErrorValidator.evWrongOutputType,
4220 TosaErrorValidator.evWrongInputList,
4221 TosaErrorValidator.evWrongOutputList,
4222 TosaErrorValidator.evWrongRank,
4223 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004224 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004225 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004226 "resize": {
4227 "op": Op.RESIZE,
4228 "operands": (1, 0),
4229 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 "build_fcn": (
4231 build_resize,
4232 TosaTensorGen.tgNHWC,
4233 TosaTensorValuesGen.tvgDefault,
4234 TosaArgGen.agResize,
4235 ),
James Ward24dbc422022-10-19 12:20:31 +01004236 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004237 "invalid_test_validators": (
4238 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 ),
4240 "error_if_validators": (
4241 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004242 TosaErrorValidator.evScaleSmallerEqualZero,
4243 TosaErrorValidator.evScaleNLargerMax,
4244 TosaErrorValidator.evScaleDLargerMax,
4245 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004246 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004247 TosaErrorValidator.evBorderSmallerMin,
4248 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 TosaErrorValidator.evWrongInputType,
4250 TosaErrorValidator.evWrongOutputType,
4251 TosaErrorValidator.evWrongRank,
4252 TosaErrorValidator.evWrongInputList,
4253 TosaErrorValidator.evWrongOutputList,
4254 TosaErrorValidator.evBatchMismatch,
4255 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004256 TosaErrorValidator.evResizeOutputShapeMismatch,
4257 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004258 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004260 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004261 "cast": {
4262 "op": Op.CAST,
4263 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004264 "build_fcn": (
4265 build_cast,
4266 TosaTensorGen.tgBasic,
4267 TosaTensorValuesGen.tvgDefault,
4268 TosaArgGen.agCast,
4269 ),
James Ward8b390432022-08-12 20:48:56 +01004270 "types": (
4271 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004272 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004273 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004274 DType.INT8,
4275 DType.INT16,
4276 DType.INT32,
4277 DType.BOOL,
4278 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 "error_if_validators": (
4280 TosaErrorValidator.evWrongInputType,
4281 TosaErrorValidator.evWrongOutputType,
4282 TosaErrorValidator.evWrongInputList,
4283 TosaErrorValidator.evWrongOutputList,
4284 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004285 },
4286 "rescale": {
4287 "op": Op.RESCALE,
4288 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004289 "build_fcn": (
4290 build_rescale,
4291 TosaTensorGen.tgBasic,
4292 TosaTensorValuesGen.tvgDefault,
4293 TosaArgGen.agRescale,
4294 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004295 "types": [
4296 DType.UINT8,
4297 DType.INT8,
4298 DType.INT16,
4299 DType.INT32,
4300 DType.INT48,
4301 DType.UINT16,
4302 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004303 "error_if_validators": (
4304 TosaErrorValidator.evInputZeroPointNotZero,
4305 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004306 TosaErrorValidator.evU16InputZeroPointNotValid,
4307 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004308 TosaErrorValidator.evScaleTrue,
4309 TosaErrorValidator.evScaleNotTrue,
4310 TosaErrorValidator.evWrongInputType,
4311 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 TosaErrorValidator.evWrongInputList,
4313 TosaErrorValidator.evWrongOutputList,
4314 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004315 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004316 # Custom
4317 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004318 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004319 # Two variants of cond_if, one that generates one of two constant tensors (no
4320 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4321 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004322 "cond_if_const": {
4323 "op": Op.COND_IF,
4324 "operands": (0, 2),
4325 "build_fcn": (
4326 build_cond_if_const,
4327 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004328 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004329 TosaArgGen.agCondIf,
4330 ),
4331 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004332 "error_if_validators": (
4333 TosaErrorValidator.evOutputListThenGraphMismatch,
4334 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004335 TosaErrorValidator.evCondIfCondNotMatchingBool,
4336 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004337 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004338 },
4339 "cond_if_binary": {
4340 "op": Op.COND_IF,
4341 "operands": (2, 0),
4342 "build_fcn": (
4343 build_cond_if_binary,
4344 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004345 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004346 TosaArgGen.agCondIf,
4347 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004348 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004349 "error_if_validators": (
4350 TosaErrorValidator.evInputListThenGraphMismatch,
4351 TosaErrorValidator.evInputListElseGraphMismatch,
4352 TosaErrorValidator.evOutputListThenGraphMismatch,
4353 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004354 TosaErrorValidator.evCondIfCondNotMatchingBool,
4355 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004356 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004357 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004358 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004359 "while_loop": {
4360 "op": Op.WHILE_LOOP,
4361 "operands": (0, 1),
4362 "build_fcn": (
4363 build_while_loop,
4364 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004365 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004366 TosaArgGen.agWhileLoop,
4367 ),
4368 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004369 "error_if_validators": (
4370 TosaErrorValidator.evInputListOutputListMismatch,
4371 TosaErrorValidator.evInputListCondGraphMismatch,
4372 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4373 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4374 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004375 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004376 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004377 },
Luke Hutton57287132023-02-06 14:54:18 +00004378 "fft2d": {
4379 "op": Op.FFT2D,
4380 "operands": (2, 0),
4381 "rank": (3, 3),
4382 "build_fcn": (
4383 build_fft2d,
4384 TosaTensorGen.tgFFT2d,
4385 TosaTensorValuesGen.tvgDefault,
4386 TosaArgGen.agFFT2d,
4387 ),
4388 "types": [DType.FP32],
4389 "error_if_validators": (
4390 TosaErrorValidator.evWrongInputType,
4391 TosaErrorValidator.evWrongOutputType,
4392 TosaErrorValidator.evWrongInputList,
4393 TosaErrorValidator.evWrongOutputList,
4394 TosaErrorValidator.evWrongRank,
4395 TosaErrorValidator.evBatchMismatch,
4396 TosaErrorValidator.evKernelNotPowerOfTwo,
4397 TosaErrorValidator.evFFTInputShapeMismatch,
4398 TosaErrorValidator.evFFTOutputShapeMismatch,
4399 ),
4400 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004401 "rfft2d": {
4402 "op": Op.RFFT2D,
4403 "operands": (1, 0),
4404 "rank": (3, 3),
4405 "build_fcn": (
4406 build_rfft2d,
4407 TosaTensorGen.tgRFFT2d,
4408 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004409 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004410 ),
4411 "types": [DType.FP32],
4412 "error_if_validators": (
4413 TosaErrorValidator.evWrongInputType,
4414 TosaErrorValidator.evWrongOutputType,
4415 TosaErrorValidator.evWrongInputList,
4416 TosaErrorValidator.evWrongOutputList,
4417 TosaErrorValidator.evWrongRank,
4418 TosaErrorValidator.evBatchMismatch,
4419 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004420 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004421 ),
4422 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004423 }
4424
Kevin Cheng550ccc52021-03-03 11:21:43 -08004425
Eric Kunzee5e26762020-10-13 16:11:07 -07004426class OutputShaper:
4427 # Methods in this class compute the expected output shape and datatype
4428 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; this constructor sets no instance state."""
4431
4432 # These methods compute the output shape and dtype, pass them to
4433 # ser.addOutput() and return its result
4434 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise binary op with broadcasting.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator (``integers``/``choice`` API) used to inject
            error cases.
        a, b: input tensors; ``.shape`` and ``.dtype`` are read.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The value returned by ``ser.addOutput`` for the computed output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Size-1 dimensions of `a` broadcast to `b` only for valid tests; error
    # tests keep the shape of `a`.
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so only draw a fuzz index when
    # there is at least one dimension. For rank >= 1 the draw is always made
    # (even when not fuzzing) so the random stream is unchanged.
    # NOTE: a rank-0 output has no dimension to fuzz for DimensionMismatch.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004467
4468 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    Rank, every dimension and dtype of ``a`` and ``b`` are checked with
    ``assert``; the output copies the shape and dtype of ``a``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004479
4480 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output always has the shape of ``a``. Its dtype is ``a.dtype``
    unless ``error_name`` is ``ErrorIf.WrongOutputType``, in which case
    ``rng`` picks a different dtype.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        # Keep the set difference so the candidate order seen by
        # rng.choice (and therefore the picked dtype) stays the same.
        out_dtype = rng.choice(list(candidates - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004498
4499 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for a select op taking ``cond``, ``a``, ``b``.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator (``integers``/``choice`` API) used to inject
            error cases.
        cond, a, b: input tensors; ``.shape`` (and ``.dtype`` of a/b) are read.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The value returned by ``ser.addOutput`` for the computed output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    # For valid tests a size-1 `cond` dimension broadcasts to the largest of
    # the three sizes; otherwise the size of `cond` is used.
    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # rng.integers(0, 0) raises ValueError, so only draw a fuzz index when
    # `a` has at least one dimension. For rank >= 1 the draw is always made
    # (even when not fuzzing) so the random stream is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004532
4533 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise comparison op.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator (``integers``/``choice`` API) used to inject
            error cases.
        a, b: input tensors; ``.shape`` and ``.dtype`` are read.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The value returned by ``ser.addOutput``; the dtype is ``DType.BOOL``
        unless a WrongOutputType error is being injected.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast: a size-1 dimension of `a` takes the size from `b` when
    # `b` has that dimension (it may not for RankMismatch tests).
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so only draw a fuzz index when
    # there is at least one dimension. For rank >= 1 the draw is always made
    # (even when not fuzzing) so the random stream is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # Candidates exclude BOOL, the expected output type.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004566
4567 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction op over ``axis``.

    Valid tests keep the input rank and set the reduced axis to size 1.
    The axis-related ErrorIf cases leave the shape as is, except that
    ShapeOfAxisNotOne replaces a size-1 axis with ``rng.integers(2, 10)``.
    WrongOutputType picks a random dtype different from ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    keeps_axis = error_name in [
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    ]
    if not keeps_axis:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Set difference kept so rng.choice sees the same candidate order.
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004595
4596 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for ARGMAX over ``axis``.

    The reduced axis is removed from the shape (unless an axis ErrorIf case
    is being generated) and the output dtype is ``DType.INT32``. The
    ArgmaxOutput* and WrongOutputType cases perturb the rank, the dimension
    sizes or the dtype using ``rng``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dimension (when more than one remains) or
        # add a trailing size-1 dimension.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every dimension; rng is drawn once per dimension, in order.
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004629
4630 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: ifm is NHWC, filter is OHWI and the output is NHWC. Each
    spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    The output dtype is ``accum_dtype`` unless an ErrorIf case overrides it.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """

    def _spatial(in_size, pad_before, pad_after, k_size, dilation, stride):
        # Output length along one spatial axis.
        span = in_size - 1 + pad_before + pad_after - (k_size - 1) * dilation
        return span // stride + 1

    out_h = _spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # error case is not hit as well.
        if pick in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if pick in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With a wrong input type, use an output type that could plausibly be valid.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 inputs (presumably
            # either may be a valid output type).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004681
4682 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: ifm is NDHWC, filter is ODHWI and the output is NDHWC. For
    spatial axis ``i`` (0=D, 1=H, 2=W) the output size is
    ``(ifm[1+i] - 1 + pad[2i] + pad[2i+1] - (filter[1+i] - 1) * dil[i])
    // stride[i] + 1``. The output dtype is ``accum_dtype`` unless an
    ErrorIf case overrides it.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        pick = rng.choice(steps)
        # Axis i is grown when pick is i + 1, and every axis when pick is 4.
        # Grow by whole multiples of the stride so the non-integer output
        # error case is not hit as well. Axes are visited in D, H, W order.
        for i in range(3):
            if pick in [i + 1, 4]:
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With a wrong input type, use an output type that could plausibly be valid.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 inputs (presumably
            # either may be a valid output type).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4743
4744 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: ifm is NHWC, filter is HWCM and the output is NHW(C*M). The
    output dtype is ``accum_dtype`` unless an ErrorIf case overrides it.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    kernel_h = filter.shape[0]
    kernel_w = filter.shape[1]

    numer_h = ifm.shape[1] - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    numer_w = ifm.shape[2] - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    out_h = numer_h // strides[0] + 1
    out_w = numer_w // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # error case is not hit as well.
        if pick in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if pick in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    # Output channels are input channels (C) times the multiplier (M).
    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # The input type is the injected error; use an output type that
        # could plausibly be valid.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # Both FP16 and FP32 are excluded for FP16 inputs (presumably either
        # may be a valid output type).
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004794
4795 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling op on an NHWC input.

    With a non-positive stride or any negative padding the test is already
    invalid, so the spatial sizes are simply set to 1. Otherwise each size is
    ``(in + pad_before + pad_after - kernel) // stride + 1``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    invalid_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if invalid_params:
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # error case is not hit as well.
        if pick in [1, 3]:
            out_h += rng.choice(steps) * stride[0]
        if pick in [2, 3]:
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Set difference kept so rng.choice sees the same candidate order.
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004832
4833 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    Shapes: input is [N, IC], filter is [OC, IC], output is [N, OC]. The
    output dtype is ``accum_dtype``; ``rng`` and ``error_name`` are accepted
    for a uniform signature but not used here.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated (or deliberately invalidated for ErrorIf
    # tests) when the arguments are generated.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004845
4846 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: a is [N, H, C], b is [N, C, W], output is [N, H, W].

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to pick a wrong output dtype.
        a, b: input tensors; ``.shape`` (and ``a.dtype``) are read.
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: optional ErrorIf case to inject.

    Returns:
        The value returned by ``ser.addOutput`` for the computed output.

    Raises:
        ValueError: for ``ErrorIf.WrongOutputType`` when ``a.dtype`` is not
            one of INT8, INT16, FP32, FP16 or BF16 (no candidate list is
            defined for other input types).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is left out (presumably the valid INT8 output type).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # INT48 is left out (presumably the valid INT16 output type).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: only integer types are offered.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(
                f"matmulOp: no WrongOutputType candidates for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004893
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Register the output tensor of a CONCAT operator.

    The output starts as a copy of the first input's shape. The extents of
    the remaining inputs along ``axis`` are added on, unless the requested
    ERROR_IF case makes that impossible (mismatched ranks or an invalid
    axis). The dtype matches the first input unless a wrong output type is
    requested.

    Args:
        ser: serializer that registers the output tensor.
        rng: random generator used for ERROR_IF perturbations.
        axis: concatenation axis.
        *a: input tensors; ``a[0]`` supplies the base shape and dtype.
        error_name: optional ErrorIf case being generated.
    """
    first = a[0]
    others = a[1:]

    output_shape = first.shape.copy()

    # Summing along the axis needs matching ranks and a valid axis.
    skip_axis_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not skip_axis_sum:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Deliberately break the "sum of input extents" rule.
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        output_dtype = first.dtype
    else:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        output_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004929
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output tensor of a PAD operator.

    Each output dimension is ``padding[i][0] + padding[i][1]`` plus the
    corresponding input dimension. ERROR_IF variants shrink one dimension
    by 1 or 2, change the rank, or choose a dtype other than ``a.dtype``.
    """
    output_shape = a.shape.copy()
    for dim_idx, extent in enumerate(a.shape):
        pad_before = padding[dim_idx][0]
        pad_after = padding[dim_idx][1]
        output_shape[dim_idx] = pad_before + pad_after + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        output_dtype = a.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004960
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a DIM operator.

    The output is registered with an empty shape list and dtype
    DType.SHAPE. For ErrorIf.WrongOutputType a random non-SHAPE dtype is
    used instead. ``a`` and ``axis`` do not influence the result here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    non_shape_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    # DType.SHAPE is not a candidate, so every pick is a wrong output type.
    return ser.addOutput([], rng.choice(list(set(non_shape_dtypes))))
4981
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output tensor of a RESHAPE operator.

    The output takes a copy of the requested ``shape``. For
    ErrorIf.TensorSizeInputOutputMismatch every dimension is grown by a
    random amount in [1, 10). For ErrorIf.WrongOutputType a dtype other
    than ``a.dtype`` is chosen.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim_idx, extent in enumerate(shape):
            output_shape[dim_idx] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        output_dtype = a.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005006
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register the output tensor of a SLICE operator.

    The output shape is a copy of ``size`` and the dtype matches ``input``.
    ERROR_IF variants perturb each extent, reuse the input shape, change
    the rank, or choose a wrong dtype. ``start`` is not used here.
    """
    if error_name != ErrorIf.WrongOutputType:
        output_dtype = input.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim_idx, extent in enumerate(output_shape):
            # Extents <= 2 are only ever grown; larger ones may move +/-2.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[dim_idx] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005040
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor of a TILE operator.

    Dimension ``i`` of the output is ``a.shape[i] * multiples[i]``.
    ERROR_IF variants change the rank or choose a dtype other than
    ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim_idx, extent in enumerate(a.shape):
        output_shape[dim_idx] = extent * multiples[dim_idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        output_dtype = a.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005069
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor of a TRANSPOSE operator.

    Output dimension ``i`` is ``a.shape[perms[i]]``. When the permutation
    is deliberately invalid (out-of-bounds or repeated index), the input
    shape is kept unpermuted. Further ERROR_IF variants grow every
    dimension, change the rank, or choose a dtype other than ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    perms_usable = error_name not in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if perms_usable:
        for dst_axis in range(len(output_shape)):
            src_axis = perms[dst_axis]
            output_shape[dst_axis] = a.shape[src_axis]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim_idx in range(len(output_shape)):
            output_shape[dim_idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        output_dtype = a.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005102
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor of a GATHER operator.

    Unless ErrorIf.WrongRank is being generated, ``values`` must be rank 3
    and ``indices`` rank 2, and both must share their first dimension. The
    output shape is [values.shape[0], indices.shape[1], values.shape[2]].
    The dtype matches ``values`` unless a wrong output type is requested.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, width, channels]

    output_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(output_shape, output_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005128
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor of a SCATTER operator.

    Unless ErrorIf.WrongRank is being generated, ``values_in``, ``indices``
    and ``input`` must be rank 3, 2 and 3, agreeing on N, W and C as
    asserted below. The output has the same shape as ``values_in`` and,
    unless a wrong output type is requested, the same dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy rather than alias values_in's shape (as the other shapers do), so
    # the output tensor never shares a mutable shape object with an input.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005157
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor of a TABLE operator.

    The output keeps the input's shape. An INT16 input yields an INT32
    output; any other input dtype (normally INT8) yields INT8. For
    ErrorIf.WrongOutputType a different dtype is picked at random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005177
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor of a RESIZE operator.

    With ``scale`` = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]:
        OH = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
        OW = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
    The output is [input.shape[0], OH, OW, input.shape[3]] with dtype
    ``output_dtype``, unless an ERROR_IF case alters it. ``mode`` and
    ``input_dtype`` are not used here.
    """
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = (scale[i] for i in range(4))

    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep every scale term >= 1 so the shape arithmetic stays defined.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    in_h = input.shape[1]
    in_w = input.shape[2]
    oh = ((in_h - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((in_w - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Scale, offset or border may have been altered for ERROR_IFs; keep
        # the output tensor itself valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_cap = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_cap), min(ow, dim_cap)

    def _nudge(extent, step):
        # Move by one scale denominator so the non-integer error case is not
        # hit; step down instead when growing would reach the maximum size.
        if extent + step >= gtu.MAX_RESIZE_DIMENSION:
            shrunk = extent - step
            assert shrunk > 0  # Should have been caught in agResize
            return shrunk
        return extent + step

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        if change in (1, 3):
            oh = _nudge(oh, scale_y_d)
        if change in (2, 3):
            ow = _nudge(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Batch size is repeated in the last slot instead of input.shape[3].
        # NOTE(review): presumably because a wrong-rank input may lack a
        # fourth dimension - confirm.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005256
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output with the same shape object as ``val`` and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for a uniform shaper signature
    but are not used.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005260
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor of a TRANSPOSE_CONV2D operator.

    Note: for ErrorIf.ConvOutputShapeMismatch, dimensions 1 and/or 2 of
    ``output_shape`` (presumably H and W) are grown *in place*, so the
    caller's object is modified as well as the registered output.

    The output dtype is ``accum_dtype``. It is INT32 when the input type is
    the deliberate error, and a random wrong type for
    ErrorIf.WrongOutputType.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)  # 1: dim 1, 2: dim 2, 3: both
        if which in (1, 3):
            output_shape[1] += rng.choice(deltas)
        if which in (2, 3):
            output_shape[2] += rng.choice(deltas)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # NOTE(review): FP16 input excludes both FP16 and FP32, presumably
        # because either is an acceptable accumulator type - confirm.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005286
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors of an FFT2D operator.

    Both inputs must share a dtype. Their shapes must match (except for
    ErrorIf.FFTInputShapeMismatch) and be rank 3 (except for
    ErrorIf.WrongRank). Both outputs take a copy of ``ifm1``'s shape and
    its dtype. ERROR_IF variants replace the dtype with a non-FP32 one,
    grow the batch dimension, or grow dimension 1 or 2.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grow_dim = rng.choice([1, 2])
        output_shape[grow_dim] += rng.integers(1, 10)

    # Both outputs are registered with the same shape object and dtype.
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5317
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors of an RFFT2D operator.

    ``value`` must be rank 3 unless ErrorIf.WrongRank is being generated.
    Both outputs keep every leading dimension of ``value`` and replace the
    last one, W, with ``W // 2 + 1``. ERROR_IF variants replace the dtype
    with a non-FP32 one, grow the batch dimension, or grow dimension 1 or 2.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        grow_dim = rng.choice([1, 2])
        output_shape[grow_dim] += rng.integers(1, 10)

    # Both outputs are registered with the same shape object and dtype.
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]