blob: 04093b8348626e464507c0d92ee937b2d4b62550 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data file; the
# copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
40 TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
41 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
69 vals.append(v)
70 return tuple(sorted(vals))
71
72 self.random_float_range = {}
73 for dtype in (DType.FP32, DType.FP16, DType.BF16):
74 self.random_float_range[dtype] = convertFPRange(
75 args.tensor_fp_value_range,
76 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
77 )
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 def createSerializer(self, opName, testPath):
80 self.testPath = os.path.join(opName, testPath)
81
82 fullPath = os.path.join(self.basePath, self.testPath)
83 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010084 # Embed const data in the flatbuffer
85 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010086 if self.args.lazy_data_gen:
87 # Lazy data generation - so make constants files
88 constMode = ts.ConstMode.INPUTS
89 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 constMode = ts.ConstMode.EMBED_DUMP
91 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
93 def getSerializer(self):
94 return self.ser
95
Jeremy Johnson1271c442023-09-05 11:39:26 +010096 def serialize(self, testName, metaData=None):
97 path = Path(self.basePath) / self.testPath
98
99 # Write out TOSA flatbuffer binary
100 path_fb = path / f"{testName}.tosa"
101 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 fd.write(self.ser.serialize())
103
Jeremy Johnson1271c442023-09-05 11:39:26 +0100104 # Get JSON descriptor from serializer
105 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
106
107 if metaData:
108 # Add extra meta data to desc.json
109 desc["meta"] = metaData
110
111 # Validate desc.json before we output it
112 self.descSchemaValidator.validate_config(desc)
113
114 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100115 if "data_gen" in metaData:
116 if self.args.lazy_data_gen:
117 # Output datagen meta data as CPP data
118 path_md = path / f"{testName}_meta_data_gen.cpp"
119 with path_md.open("w") as fd:
120 fd.write(TOSA_AUTOGENERATED_HEADER)
121 fd.write("// Test meta data for data generation setup\n\n")
122 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
123 json.dump(metaData["data_gen"], fd)
124 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100125 if "compliance" in metaData:
126 # Output datagen meta data as CPP data
127 path_md = path / f"{testName}_meta_compliance.cpp"
128 with path_md.open("w") as fd:
129 fd.write(TOSA_AUTOGENERATED_HEADER)
130 fd.write("// Test meta data for compliance validation\n\n")
131 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
132 json.dump(metaData["compliance"], fd)
133 fd.write(')";\n\n')
134
135 # Write desc.json
136 path_desc = path / "desc.json"
137 with path_desc.open("w") as fd:
138 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
Matthew Haddon74567092021-07-16 15:38:20 +0100140 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000141 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100142 seed = self.random_seed + 1
143 self.rng = np.random.default_rng(seed)
144
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 def getDTypeRange(self, dtype, high_inclusive=False):
146 # Returns dtype value range boundaries (low, high)
147 # The high boundary is excluded in the range
148 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100149 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100150 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 elif dtype == DType.BOOL:
152 rng = (0, 2)
153 elif dtype == DType.UINT8:
154 rng = (0, 256)
155 elif dtype == DType.UINT16:
156 rng = (0, 65536)
157 elif dtype == DType.INT4:
158 # TOSA specific INT4 weight range from -7 to 7
159 rng = (-7, 8)
160 elif dtype == DType.INT8:
161 rng = (-128, 128)
162 elif dtype == DType.INT16:
163 rng = (-32768, 32768)
164 elif dtype in (DType.INT32, DType.SHAPE):
165 # restricting too large value for SHAPE
166 rng = (-(1 << 31), (1 << 31))
167 elif dtype == DType.INT48:
168 rng = (-(1 << 47), (1 << 47))
169 else:
170 raise Exception("Unknown dtype: {}".format(dtype))
171
172 if not high_inclusive:
173 # Exclusive high: low <= range < high
174 return rng
175 else:
176 # Inclusive range: low <= range <= high
177 return (rng[0], rng[1] - 1)
178
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000179 def getRandTensor(self, shape, dtype, data_range=None):
180 if data_range is None:
181 low, high = self.getDTypeRange(dtype)
182 else:
183 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100184
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100188 return np.int64(self.rng.integers(low=low, high=high, size=shape))
189 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
190 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
191
192 if dtype == DType.FP16:
193 return np.float16(f_tensor)
194 else:
195 f32_tensor = np.float32(f_tensor)
196 if dtype == DType.BF16:
197 # Floor the last 16 bits of each f32 value
198 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
199 else:
200 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100202 # All other integer types
203 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 placeholders = []
207
Kevin Cheng989cb052021-04-28 16:29:44 -0700208 assert len(shape_list) == len(dtype_list)
209
Jeremy Johnson1271c442023-09-05 11:39:26 +0100210 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100212 if not self.args.lazy_data_gen:
213 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700214 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
216 return placeholders
217
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 consts = []
220
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 assert len(shape_list) == len(dtype_list)
222
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100225 if not self.args.lazy_data_gen:
226 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700227 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700228
229 return consts
230
231 def makeShape(self, rank):
232 if self.targetted_shape:
233 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800234 return np.int32(
235 self.rng.integers(
236 low=self.args.tensor_shape_range[0],
237 high=self.args.tensor_shape_range[1],
238 size=rank,
239 )
240 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
242 def setTargetShape(self, shape):
243 self.targetted_shape = shape
244
245 def randInt(self, low=0, high=256):
246 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
247
248 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100249 low, high = self.getDTypeRange(dtype)
250
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100251 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100252 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100253 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100254 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100255 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
257 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 elif dtype == DType.BOOL:
259 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700260 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 # Special size
262 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700263
264 return np.int32(self.rng.integers(low, high, size=1))[0]
265
266 def shapeStr(self, shape):
267
268 sStr = []
269 # Convert to strings
270 for i in shape:
271 sStr.append(str(i))
272
Kevin Cheng550ccc52021-03-03 11:21:43 -0800273 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100275 def typeStr(self, dtype):
276 if isinstance(dtype, list) or isinstance(dtype, tuple):
277 assert len(dtype) >= 2
278 strs = [self.typeStr(t) for t in dtype]
279 # Limit types to the first 2 as the 3rd is the accumulator
280 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100282 if dtype in gtu.DTYPE_ATTRIBUTES:
283 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700284 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100285 raise Exception(
286 "Unknown dtype, cannot convert to string: {}".format(dtype)
287 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100289 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100290 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100291 if dtype in gtu.DTYPE_ATTRIBUTES:
292 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100294 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
Luke Hutton57287132023-02-06 14:54:18 +0000296 def constrictBatchSize(self, shape):
297 # Limit the batch size unless an explicit target shape set
298 if self.args.max_batch_size and not self.args.target_shapes:
299 shape[0] = min(shape[0], self.args.max_batch_size)
300 return shape
301
James Ward30124a82023-02-02 14:56:33 +0000302 def makeDimension(self):
303 return self.randInt(
304 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
305 )
306
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100307 def tensorComplianceMetaData(
308 self, op, inputType, argsDict, outputTensor, errorName
309 ):
310 if (
311 errorName
312 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
313 or not gtu.dtypeIsSupportedByCompliance(inputType)
314 ):
315 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100316 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100317
Jeremy Johnson1271c442023-09-05 11:39:26 +0100318 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100319 compliance_tens = {
320 "mode": None,
321 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
322 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
323 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100324 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
325 mode = gtu.ComplianceMode.DOT_PRODUCT
326 compliance_tens["dot_product_info"] = {
327 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100328 "ks": int(argsDict["ksb"])
329 if "ksb" in argsDict
330 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100331 }
332 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
333 mode = gtu.ComplianceMode.FP_SPECIAL
334 elif "compliance" in op and "ulp" in op["compliance"]:
335 mode = gtu.ComplianceMode.ULP
336 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
337 elif op["op"] == Op.REDUCE_PRODUCT:
338 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnson9a758382023-11-07 16:27:35 +0000339 elif op["op"] in (Op.EXP, Op.POW):
340 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnson1271c442023-09-05 11:39:26 +0100341 else:
342 mode = gtu.ComplianceMode.EXACT
343 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
344
345 return compliance_tens
346
347 # Build Op functions
348 # Create the output tensor (calling OutputShaper as needed)
349 # Do final tweaks to attributes (if necessary for errorIf)
350 # Add Op into graph
351 # Return resulting tensor information or BuildInfo
352
353 class BuildInfo:
354 """Enhanced build information containing result tensor and associated compliance dict."""
355
356 def __init__(self, resultTensor, complianceDict):
357 self.resultTensor = resultTensor
358 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700359
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000360 def build_unary(
361 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
362 ):
363 assert len(inputs) == 1
364 a = inputs[0]
365 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100366
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000367 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100368
369 # Ensure new output type has correct qinfo
370 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000371 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000372 qinfo = [
373 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000374 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000375 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100376
377 # Invalidate Input/Output list for error if checks.
378 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000379 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100380 pCount, cCount = op["operands"]
381 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000382 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
383 self, error_name, input_list, output_list
384 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100385
Les Bell729b0352021-11-24 10:28:21 +0000386 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100387 self.ser,
388 validator_fcns,
389 error_name,
390 op=op,
391 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000392 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000394 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100395 input_list=input_list,
396 output_list=output_list,
397 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000398 ):
399 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100400
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000401 attr = None
402 if op["op"] == Op.NEGATE:
403 attr = ts.TosaSerializerAttribute()
404 attr.NegateAttribute(qinfo[0], qinfo[1])
405
406 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000407
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000408 compliance = self.tensorComplianceMetaData(
409 op, a.dtype, args_dict, result_tensor, error_name
410 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000411 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000413 def build_binary_broadcast(
414 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
415 ):
416 assert len(inputs) == 2
417 a, b = inputs
418 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000419 self.ser, self.rng, a, b, error_name
420 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100421
422 # Invalidate Input/Output list for error if checks.
423 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000424 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100425 pCount, cCount = op["operands"]
426 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
428 self, error_name, input_list, output_list
429 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100430
Les Bell729b0352021-11-24 10:28:21 +0000431 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100432 self.ser,
433 validator_fcns,
434 error_name,
435 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 input1=a,
437 input2=b,
438 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000439 output_dtype=result_tensor.dtype,
440 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100441 input_list=input_list,
442 output_list=output_list,
443 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000444 ):
445 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100446
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000447 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000448
Jeremy Johnson9a758382023-11-07 16:27:35 +0000449 compliance = self.tensorComplianceMetaData(
450 op, a.dtype, args_dict, result_tensor, error_name
451 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000452
453 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700454
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100455 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700456 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700458 return result_tens
459
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000460 def build_arithmetic_right_shift(
461 self, op, a, b, round, validator_fcns=None, error_name=None
462 ):
463 result_tens = OutputShaper.binaryBroadcastOp(
464 self.ser, self.rng, a, b, error_name
465 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100466
467 # Invalidate Input/Output list for error if checks.
468 input_list = [a.name, b.name]
469 output_list = [result_tens.name]
470 pCount, cCount = op["operands"]
471 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000472 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
473 self, error_name, input_list, output_list
474 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100475
Les Bell729b0352021-11-24 10:28:21 +0000476 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100477 self.ser,
478 validator_fcns,
479 error_name,
480 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000481 input1=a,
482 input2=b,
483 input_dtype=a.dtype,
484 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000485 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100486 input_list=input_list,
487 output_list=output_list,
488 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000489 ):
490 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800491
492 attr = ts.TosaSerializerAttribute()
493 attr.ArithmeticRightShiftAttribute(round)
494
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000495 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800496 return result_tens
497
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100498 def build_mul(
499 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
500 ):
501 assert len(inputs) == 2
502 a, b = inputs
503 shift = args_dict["shift"]
504
505 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000506 self.ser, self.rng, a, b, error_name
507 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700508
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100509 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100510 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100511 result_tensor.setDtype(DType.INT32)
512
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100513 if error_name == ErrorIf.WrongOutputType:
514 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
515 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100516 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100517
518 # Invalidate Input/Output list for error if checks.
519 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100520 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100521 pCount, cCount = op["operands"]
522 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000523 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
524 self, error_name, input_list, output_list
525 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100526
Les Bell729b0352021-11-24 10:28:21 +0000527 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100528 self.ser,
529 validator_fcns,
530 error_name,
531 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 input1=a,
533 input2=b,
534 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100535 output_dtype=result_tensor.dtype,
536 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100537 input_list=input_list,
538 output_list=output_list,
539 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000540 ):
541 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700542
Kevin Chengaee1fac2020-11-11 13:54:06 -0800543 attr = ts.TosaSerializerAttribute()
544 attr.MulAttribute(shift)
545
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000546 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100547
548 compliance = self.tensorComplianceMetaData(
549 op, a.dtype, args_dict, result_tensor, error_name
550 )
551
552 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
555 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700556
Kevin Chengfe392ce2021-10-18 21:51:55 +0000557 attr = ts.TosaSerializerAttribute()
558 attr.TableAttribute(table)
559
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100560 # Invalidate Input/Output list for error if checks.
561 input_list = [a.name]
562 output_list = [result_tens.name]
563 pCount, cCount = op["operands"]
564 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000565 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
566 self, error_name, input_list, output_list
567 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100568
Les Bell729b0352021-11-24 10:28:21 +0000569 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100570 self.ser,
571 validator_fcns,
572 error_name,
573 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000574 input_shape=a.shape,
575 input_dtype=a.dtype,
576 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000577 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100578 input_list=input_list,
579 output_list=output_list,
580 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000581 ):
582 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100583
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700585
586 return result_tens
587
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100588 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
589 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
590
591 # Invalidate Input/Output list for error if checks.
592 input_list = [cond.name, a.name, b.name]
593 output_list = [result_tens.name]
594 pCount, cCount = op["operands"]
595 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000596 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
597 self, error_name, input_list, output_list
598 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100599
Les Bell729b0352021-11-24 10:28:21 +0000600 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100601 self.ser,
602 validator_fcns,
603 error_name,
604 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000605 input1=cond,
606 input2=a,
607 input3=b,
608 input_shape=a.shape,
609 input_dtype=a.dtype,
610 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000611 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100612 input_list=input_list,
613 output_list=output_list,
614 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000615 ):
616 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000618 self.ser.addOperator(
619 op["op"],
620 input_list,
621 output_list,
622 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700623 return result_tens
624
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100625 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000626 result_tens = OutputShaper.binaryComparisonOp(
627 self.ser, self.rng, a, b, error_name
628 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100629
630 # Invalidate Input/Output list for error if checks.
631 input_list = [a.name, b.name]
632 output_list = [result_tens.name]
633 pCount, cCount = op["operands"]
634 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000635 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
636 self, error_name, input_list, output_list
637 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100638
Les Bell729b0352021-11-24 10:28:21 +0000639 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100640 self.ser,
641 validator_fcns,
642 error_name,
643 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000644 input1=a,
645 input2=b,
646 input_shape=a.shape,
647 input_dtype=a.dtype,
648 output_shape=result_tens.shape,
649 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000650 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100651 input_list=input_list,
652 output_list=output_list,
653 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000654 ):
655 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100656
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000657 self.ser.addOperator(
658 op["op"],
659 input_list,
660 output_list,
661 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700662 return result_tens
663
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX operator over a single input tensor.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700707
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; must contain "stride", "pad" and "kernel".
            "acc_type" is optional (max_pool has none) and defaults to
            DType.UNKNOWN.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for PoolAttribute, or None to use
            [0, 0].

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    # Named "ifm" rather than "input" so the builtin is not shadowed.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    # The serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100781
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            sum(op["operands"]) is the operand count given to the validators.
        inputs: Sequence of exactly three tensors: input, weight and bias.
        args_dict: Test arguments; must contain "acc_type", "stride",
            "pad" (4 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for ConvAttribute, or None to use
            [0, 0]. Previously None caused a TypeError when indexed.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" so the builtin is not shadowed.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700863
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            sum(op["operands"]) is the operand count given to the validators.
        inputs: Sequence of exactly three tensors: input, weight and bias.
        args_dict: Test arguments; must contain "acc_type", "stride",
            "pad" (6 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for ConvAttribute, or None to use
            [0, 0]. Previously None caused a TypeError when indexed.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" so the builtin is not shadowed.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
940
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            sum(op["operands"]) is the operand count given to the validators.
        ifm: Input tensor.
        filter: Weight tensor. The name shadows the builtin but is kept
            because it is part of the call interface.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype passed to OutputShaper.
        stride: Stride values for the attribute.
        out_pad: Output padding; must contain 4 values.
        output_shape: Output shape passed to OutputShaper and the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for TransposeConvAttribute, or None
            to use [0, 0]. Previously None caused a TypeError when indexed.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1008
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            sum(op["operands"]) is the operand count given to the validators.
        inputs: Sequence of exactly three tensors: input, weight and bias.
        args_dict: Test arguments; must contain "acc_type", "stride", "pad"
            and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for ConvAttribute, or None to use
            [0, 0]. Previously None caused a TypeError when indexed.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" so the builtin is not shadowed.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1084
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence of exactly three tensors: input, weight and bias.
        args_dict: Test arguments; must contain "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for FullyConnectedAttribute, or
            None to use [0, 0]. Previously None caused a TypeError when
            indexed.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" so the builtin is not shadowed.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001140
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator from two input tensors.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence of exactly two input tensors (a, b).
        args_dict: Test arguments; must contain "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Two-element zero-point list for MatMulAttribute, or None to use
            [0, 0]. Previously None caused a TypeError when indexed.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The validators above see the caller's qinfo unchanged; only the
    # serialized attribute needs concrete zero points.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001190
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator along a single axis of one input tensor.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and compliance meta
        data (None for REDUCE_PRODUCT, which has no compliance support yet),
        or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    reduced = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [reduced.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=reduced.shape,
        input_dtype=src.dtype,
        output_dtype=reduced.dtype,
        result_tensors=[reduced],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    if op["op"] != Op.REDUCE_PRODUCT:
        compliance = self.tensorComplianceMetaData(
            op, src.dtype, args_dict, reduced, error_name
        )
    else:
        # TODO: Add compliance support for REDUCE_PRODUCT!
        compliance = None

    return TosaTestGen.BuildInfo(reduced, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001239
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly chosen min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes the minimum and the larger one the maximum. For the
    ErrorIf.MaxSmallerMin test the values are redrawn until they differ and
    then swapped, so that max < min.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    first = self.getRandNumberDType(src.dtype)
    second = self.getRandNumberDType(src.dtype)

    swap_bounds = error_name == ErrorIf.MaxSmallerMin
    if swap_bounds:
        # Make sure the numbers are different to invoke this error
        while first == second:
            first = self.getRandNumberDType(src.dtype)
            second = self.getRandNumberDType(src.dtype)

    low = min(first, second)
    high = max(first, second)
    if swap_bounds:
        min_val, max_val = high, low
    else:
        min_val, max_val = low, high

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    is_float = src.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float:
        # Integer bounds go in the int slots; float slots are zero.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float bounds go in the float slots; int slots are zero.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001305
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on tensor ``a``.

    The LeakyReluAttribute value is a random FP32 number. No ERROR_IF
    validation is done here; validator_fcns is not used.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    leaky_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out.name], leaky_attr)
    return out
1314
# NOTE: PRELU needs an additional type/input; only ``a`` is passed to the
# operator below.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator from the single tensor ``a``.

    No ERROR_IF validation is done here; validator_fcns is not used.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1321
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that has no attribute.

    The output tensor is shaped by OutputShaper.unaryOp.

    Args:
        op: Operator description dict; op["op"] is the opcode to serialize and
            op["operands"] unpacks into two operand counts that are summed.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf variant being tested, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001362
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT operator joining every tensor in ``inputs``.

    The axis is read from ``args_dict["axis"]`` and must be a plain ``int``
    unless a WrongInputType ERROR_IF test is being built.

    Returns a ``TosaTestGen.BuildInfo`` for the result tensor (without
    compliance metadata), or ``None`` when the ERROR_IF validators reject
    the generated test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Start from the real operand names; the ERROR_IF helper may rewrite the
    # lists to produce an invalid input/output list test.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor.name for tensor in inputs], [out_tensor.name]
    )

    lead = inputs[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=lead.shape,
        output_shape=out_tensor.shape,
        input_dtype=lead.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001410
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator for the single tensor in ``inputs``.

    ``args_dict`` provides "pad" (an array that is flattened into the
    attribute), "pad_const_int" and "pad_const_fp".

    Returns a ``TosaTestGen.BuildInfo`` carrying compliance metadata, or
    ``None`` when the ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # PadAttribute is handed the serializer builder, so it is created ahead
    # of validation to keep the same sequence of builder calls.
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_fp
    )

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001468
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator for the single tensor in ``inputs``.

    The axis is read from ``args_dict["axis"]`` and stored as an
    AxisAttribute.

    Returns a ``TosaTestGen.BuildInfo`` (without compliance metadata), or
    ``None`` when the ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001514
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator on ``a`` with ``newShape`` as its attribute.

    Returns the result tensor itself (not a BuildInfo), or ``None`` when the
    ERROR_IF validators reject the generated test.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": param_count + const_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1550
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator for the single tensor in ``inputs``.

    The axis is read from ``args_dict["axis"]``. The output tensor is shaped
    by ``OutputShaper.unaryOp``.

    Returns a ``TosaTestGen.BuildInfo`` (without compliance metadata), or
    ``None`` when the ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not accepted:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001590
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator on ``a`` with ``perms`` as its attribute.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # Attribute is created before validation (same order as before).
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1626
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator on ``a`` using ``start`` and ``size``.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=a,
    )
    if not accepted:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1665
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator on ``a`` with ``multiples`` as its attribute.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1700
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator reading from ``values`` with random indices.

    A constant INT32 index tensor of shape ``(values.shape[0], W)`` is
    created here. W is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``, and every index lies in
    ``[0, values.shape[1])``.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    # Keep every index inside the K dimension (axis 1) of values.
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1747
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 index tensor of shape
    ``(values_in.shape[0], input.shape[1])`` is created here, with every
    index in ``[0, values_in.shape[1])``.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    # Keep every index inside the K dimension (axis 1) of values_in.
    # NOTE(review): indices are drawn with replacement, so a batch may repeat
    # an index; the TOSA SCATTER spec expects each output position to be
    # written at most once - confirm duplicates are acceptable here.
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1792
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add a RESIZE operator on ``input``.

    ``scale``, ``offset``, ``border`` and ``mode`` are stored in the
    ResizeAttribute. The caller-supplied ``input_dtype``/``output_dtype``
    (not the tensors' own dtypes) are passed to the ERROR_IF validators.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=param_count + const_count,
    )
    if not accepted:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
1854
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add an IDENTITYN operator with inputs ``val`` and ``val2``.

    One output tensor per input is shaped with ``OutputShaper.unaryOp``;
    only the first output tensor is returned. No ERROR_IF validation is
    performed for this operator.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Every other builder passes the op code stored under op["op"] to the
    # serializer; passing the whole op-description dict is a type mismatch.
    # A bare op code is still accepted for backward compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1862
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the graph and return it.

    No operator is added. ``op``, ``validator_fcns`` and ``error_name`` are
    accepted for signature parity with the other builders and are unused.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001866
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add a CAST operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # The ERROR_IF helper may rewrite these lists for invalid-operand tests.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=val.shape,
        output_shape=out_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1900
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add a RESCALE operator converting ``val`` to ``out_dtype``.

    Zero points and scale multipliers/shifts are chosen randomly from
    ``self.rng``. There is one scale per element of the last dimension when
    ``per_channel`` is set, otherwise a single scale. For ``scale32`` tests
    with no expected error, the placeholder data file backing ``val`` is
    clamped and rewritten so its values satisfy the apply_scale_32 range
    requirement.

    Returns the result tensor, or ``None`` when the ERROR_IF validators
    reject the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input zero point. Every branch that may pick a non-zero zero point also
    # widens the type width by one bit (presumably headroom for the offset).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Output zero point, chosen with the same scheme as the input.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds np.int32 scalars. Under NumPy 2 (NEP 50) promotion,
        # shifting an int32 by more than 31 bits overflows, so do the bound
        # arithmetic explicitly in int64 (what NumPy 1.x promoted to).
        shift = np.int64(shift_arr[i])
        min_shift_value_arr[i] = np.int64(-1) << (shift - 1)
        max_shift_value_arr[i] = (np.int64(1) << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07002054 return result_tens
2055
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. For the
    CondIfCondNotMatchingBool ERROR_IF test the element type is picked by
    ``gtu.get_wrong_output_type``; for CondIfCondShapeNotSizeOne the shape is
    randomly chosen to be [2] or [1, 2].
    """
    # Element type is decided first (it may draw from self.rng)
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    # A valid condition must be of size 1 (rank 0)
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2072
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF operator whose then/else blocks each output a constant.

    The supplied then/else tensors are only used for their shape (the shape
    of ``then_tens`` becomes the output shape) and ``cond`` provides the
    condition value. For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF tests the matching block outputs a constant whose shape has been
    deliberately changed.

    Returns:
        The INT32 result tensor, or None if ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Shift every dimension; dims of 3 or less are only grown so that
        # they stay positive
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            delta = (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # Result tensor shaped like the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block contains a single constant which is also its output
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, block_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if validated else None
2141
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks combine ``a`` and ``b``.

    For FP32/BF16/FP16/INT32 inputs the then block applies ADD and the else
    block SUB; for INT8/INT16 they apply LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT respectively. For the CondIfInputList*/
    CondIfOutputList* ERROR_IF tests, the selected block is given an input or
    output whose shape deliberately differs from ``a``.

    Args:
        op: operator dictionary from the op list (passed to the validators).
        a, b: input tensors of the COND_IF.
        cond: value of the condition constant.
        validator_fcns: ERROR_IF validator functions, if any.
        error_name: the ERROR_IF being tested, or None.

    Returns:
        The result tensor, or None if ERROR_IF validation fails.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else ops defined.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Shift every dimension so the block tensor cannot match `a`; dims of
        # 3 or less are only grown so no dimension becomes zero or negative
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise (not `assert`) so this still fires under `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be named `op`: that would overwrite the
    # operator dictionary that is passed to the ERROR_IF validation below
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2224
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` while a counter is positive.

    The loop carries (counter, a, accumulator). The counter starts at
    ``iter_val`` and the accumulator at zero. The COND block outputs
    ``counter > 0``. The BODY block outputs ``counter - 1``, ``a`` unchanged
    and ``acc + a``. For the InputList*/CondGraph* ERROR_IF tests, block
    inputs/outputs or the condition output are deliberately given wrong
    shapes or types.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation fails.
    """
    # Scalar INT32 loop counter
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    def perturbed(tens):
        # Deep copy of `tens` with every existing dimension shifted by a
        # random non-zero amount so its shape no longer matches
        wrong = deepcopy(tens)
        for dim_idx in range(len(wrong.shape)):
            wrong.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        return wrong

    # Tensors for everything leaving the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape = (
        perturbed(acc).shape
        if error_name == ErrorIf.InputListOutputListMismatch
        else acc.shape
    )
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    correct_inputs = (iter_tens, a, acc)
    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        incorrect_iter = perturbed(iter_tens)
        if not incorrect_iter.shape:
            # A rank 0 counter has no dimension to shift, so add one instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = perturbed(acc)
        wrong_inputs = (incorrect_iter, a, incorrect_acc)

    # COND block: inputs (iter, a, acc), output is the loop condition
    self.ser.addBasicBlock(cond_block)
    cond_inputs = (
        wrong_inputs
        if error_name == ErrorIf.InputListCondGraphMismatch
        else correct_inputs
    )
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    cond_shape = []
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)
    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block: inputs (iter, a, acc), outputs (iter - 1, a, acc + a).
    # Local intermediate tensors are declared here for the outputs.
    self.ser.addBasicBlock(body_block)
    body_inputs = (
        wrong_inputs
        if error_name == ErrorIf.InputListBodyGraphInputMismatch
        else correct_inputs
    )
    for tens in body_inputs:
        self.ser.addInputTensor(tens)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = incorrect_iter, incorrect_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if validated else None
2345
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator taking the two input tensors ``val1``/``val2``.

    Result tensors are created by ``OutputShaper.fft2dOp``. Before ERROR_IF
    validation, the operand name lists are passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``. The local_bound
    attribute is currently always False.

    Returns:
        The list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # Let the ERROR_IF helper invalidate the operand name lists if required
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not validated:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2396
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator on the single input tensor ``val``.

    Result tensors are created by ``OutputShaper.rfft2dOp``. Before ERROR_IF
    validation, the operand name lists are passed through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``. The local_bound
    attribute is currently always False.

    Returns:
        The list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # Let the ERROR_IF helper invalidate the operand name lists if required
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not validated:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)
    return results
2442
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Args:
        op: operator dictionary; its "rank" (min, max) and "types" are used.
        shapeFilter: requested shapes; a falsy value means no restriction.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. A list-valued op dtype
            matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, used for "negative" tests to supply
            rank/dtype/shape requirements via its "param_reqs".

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or an unknown testType.
    """
    # Default to ranks 1-4 inclusive to keep test sizes reasonably small
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    low_rank, high_rank = op["rank"]
    if rankFilter is not None:
        # Only keep requested ranks that the operator supports
        ranks = [rank for rank in rankFilter if low_rank <= rank <= high_rank]
    elif shapes[0] is None:
        # Nothing requested - use whichever of the operator's rank range and
        # the default range is smaller
        op_ranks = range(low_rank, high_rank + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(low_rank, high_rank + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    # ERROR_IF tests may need specific ranks/dtypes/shapes
    reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = reqs["rank"]
    dtype_req = reqs["dtype"]
    shape_req = reqs["shape"]
    return {
        # Without a shape requirement, use at most two shapes to keep the
        # number of tests small
        "shapeFilter": shapes[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2523
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to build for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict generation to.
            None (the default) means no shape restriction and is passed on
            to ``create_filter_lists`` as-is.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests or "negative" for ERROR_IF
            tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. Positive tests flagged by any of the op's
        "invalid_test_validators" are removed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Re-create the random number generator from the fixed seed on every call
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2630
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002631 def serializeTest(
2632 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2633 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002634 try:
2635 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002636 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002637 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002638
Jeremy Johnson0c716862023-04-13 17:18:19 +01002639 if self.args.verbose:
2640 print(f"Creating {testStr}")
2641
Eric Kunzee5e26762020-10-13 16:11:07 -07002642 # Create a serializer
2643 self.createSerializer(opName, testStr)
2644
Jeremy Johnson1271c442023-09-05 11:39:26 +01002645 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002646 if "error_if_validators" in op:
2647 error_if_validators = op["error_if_validators"]
2648 else:
2649 error_if_validators = None
2650
Kevin Cheng550ccc52021-03-03 11:21:43 -08002651 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002652 num_operands = pCount + cCount
2653
2654 if isinstance(dtype_or_dtypeList, list):
2655 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002656 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002657 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002658 else:
2659 dtypeList = [dtype_or_dtypeList] * (num_operands)
2660
Kevin Cheng93a16282021-08-31 16:14:03 -07002661 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002662 assert (
2663 len(shapeList) == num_operands
2664 ), "shapeList length {} must match number of operands {}".format(
2665 len(shapeList), num_operands
2666 )
2667 assert (
2668 len(dtypeList) == num_operands
2669 ), "dtypeList length {} must match number of operands {}".format(
2670 len(dtypeList), num_operands
2671 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002672
2673 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002674 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002675 except KeyError:
2676 qgen = None
2677
2678 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002679
Matthew Haddon1c00b712021-10-01 15:51:03 +01002680 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002681 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002682 else:
2683 qinfo = None
2684
Jeremy Johnson1271c442023-09-05 11:39:26 +01002685 # Extra meta data for the desc.json
2686 tensMeta = {}
2687
2688 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002689 if isinstance(testArgs, dict):
2690 # New interface with args info in dictionary
2691 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002692 assert "dg_type" in argsDict
2693 tvgInfo = tvgen_fcn(
2694 self, opName, dtypeList, shapeList, argsDict, error_name
2695 )
2696 if tvgInfo.dataGenDict:
2697 tensMeta["data_gen"] = tvgInfo.dataGenDict
2698 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002699
2700 result = build_fcn(
2701 self,
2702 op,
2703 tens,
2704 argsDict,
2705 validator_fcns=error_if_validators,
2706 error_name=error_name,
2707 qinfo=qinfo,
2708 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002709 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002710 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002711 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002712
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002713 try:
2714 if error_if_validators is None:
2715 if qinfo is not None:
2716 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2717 else:
2718 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002719 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002720 if qinfo is not None:
2721 result = build_fcn(
2722 self,
2723 op,
2724 *tens,
2725 *testArgs,
2726 validator_fcns=error_if_validators,
2727 error_name=error_name,
2728 qinfo=qinfo,
2729 )
2730 else:
2731 result = build_fcn(
2732 self,
2733 op,
2734 *tens,
2735 *testArgs,
2736 validator_fcns=error_if_validators,
2737 error_name=error_name,
2738 )
2739 except TypeError as e:
2740 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2741 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002742
Jeremy Johnson1271c442023-09-05 11:39:26 +01002743 if result:
Les Bell729b0352021-11-24 10:28:21 +00002744 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002745 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2746 # Add the compliance meta data
2747 # NOTE: This currently expects only one result output
2748 tensMeta["compliance"] = {
2749 "version": "0.1",
2750 "tensors": {result.resultTensor.name: result.complianceDict},
2751 }
2752 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002753 else:
2754 # The test is not valid
2755 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002756
Eric Kunzee5e26762020-10-13 16:11:07 -07002757 def createDynamicOpLists(self):
2758
Jeremy Johnson00423432022-09-12 17:27:37 +01002759 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2760 # Already created these lists (can occur when class is initialized more than once)
2761 return
2762
Eric Kunzee5e26762020-10-13 16:11:07 -07002763 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002764 if not self.args.level8k:
2765 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2766 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2767 else:
2768 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2769 KERNELS_2D = [[1, bigK], [bigK, 2]]
2770 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002771
Kevin Cheng1533b852021-09-01 12:51:58 -07002772 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002773 testName = "conv2d_{}x{}".format(k[0], k[1])
2774 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2775 self.TOSA_OP_LIST[testName]["filter"] = k
2776 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002777
Kevin Cheng550ccc52021-03-03 11:21:43 -08002778 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2779 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2780 "depthwise_conv2d_TEMPLATE"
2781 ].copy()
2782 self.TOSA_OP_LIST[testName]["filter"] = k
2783 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002784
Kevin Cheng550ccc52021-03-03 11:21:43 -08002785 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2786 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2787 "transpose_conv2d_TEMPLATE"
2788 ].copy()
2789 self.TOSA_OP_LIST[testName]["filter"] = k
2790 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002791
Kevin Cheng1533b852021-09-01 12:51:58 -07002792 for k in KERNELS_3D:
2793 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2794 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2795 self.TOSA_OP_LIST[testName]["filter"] = k
2796 self.TOSA_OP_LIST[testName]["template"] = False
2797
Eric Kunzee5e26762020-10-13 16:11:07 -07002798 # Delete any templates after having created any dynamic ops
2799 # This is a two-pass operation because it's bad practice to delete
2800 # keys from dictionaries while iterating
2801 keyList = []
2802 for k in self.TOSA_OP_LIST:
2803 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002804 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002805 keyList.append(k)
2806 continue
2807 except KeyError:
2808 pass
2809
2810 for k in keyList:
2811 del self.TOSA_OP_LIST[k]
2812
2813 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002814 """Fill in default fields for ops if they aren't already specified.
2815 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002816 for op in self.TOSA_OP_LIST:
2817
2818 # Required fields
2819 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002820 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002821 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002822 raise Exception(
2823 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2824 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002825
2826 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002827 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002828 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002829 raise Exception(
2830 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2831 op
2832 )
2833 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002834
2835 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002836 _ = self.TOSA_OP_LIST[op]["types"]
2837 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002838 raise Exception(
2839 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2840 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002841
2842 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002843 _ = self.TOSA_OP_LIST[op]["op"]
2844 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002845 raise Exception(
2846 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2847 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002848
2849 # Put in default rank range, if missing
2850 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002852 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002853 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002854
2855 # Tensor operator list
2856 # 'op': op name
2857 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002858 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2859 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002860 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2861 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002862 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002863
Kevin Cheng550ccc52021-03-03 11:21:43 -08002864 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002865 TYPE_INT_FP = [
2866 DType.INT8,
2867 DType.INT16,
2868 DType.INT32,
2869 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002870 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002871 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002872 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002873
Kevin Cheng550ccc52021-03-03 11:21:43 -08002874 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002875 TYPE_FI32 = [
2876 DType.FP32,
2877 DType.FP16,
2878 DType.BF16,
2879 DType.INT32,
2880 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002881 TYPE_FIB = [
2882 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002883 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002884 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002885 DType.INT8,
2886 DType.INT16,
2887 DType.INT32,
2888 DType.BOOL,
2889 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002890 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002891
James Ward24dbc422022-10-19 12:20:31 +01002892 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002893
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002894 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002895 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002896 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002897 [DType.INT8, DType.INT8, DType.INT32],
2898 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002899 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002900 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002901 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002902 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002903 ]
2904
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002905 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002906
2907 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002908 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002909 "argmax": {
2910 "op": Op.ARGMAX,
2911 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002912 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002913 "build_fcn": (
2914 build_argmax,
2915 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002916 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002917 TosaArgGen.agAxis,
2918 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002919 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002920 "error_if_validators": (
2921 TosaErrorValidator.evAxisSmallerZero,
2922 TosaErrorValidator.evAxisLargerRank,
2923 TosaErrorValidator.evArgmaxOutputRankMismatch,
2924 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2925 TosaErrorValidator.evWrongRank,
2926 TosaErrorValidator.evWrongInputType,
2927 TosaErrorValidator.evWrongOutputType,
2928 TosaErrorValidator.evWrongInputList,
2929 TosaErrorValidator.evWrongOutputList,
2930 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002931 "data_gen": {
2932 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2933 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002934 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 "avg_pool2d": {
2936 "op": Op.AVG_POOL2D,
2937 "operands": (1, 0),
2938 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 "build_fcn": (
2940 build_pool2d,
2941 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002942 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002943 TosaArgGen.agPooling,
2944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002945 "qgen": TosaQuantGen.qgUnary,
2946 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002947 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002948 "error_if_validators": (
2949 TosaErrorValidator.evKernelSmallerOne,
2950 TosaErrorValidator.evStrideSmallerOne,
2951 TosaErrorValidator.evPadSmallerZero,
2952 TosaErrorValidator.evWrongRank,
2953 TosaErrorValidator.evWrongInputType,
2954 TosaErrorValidator.evWrongOutputType,
2955 TosaErrorValidator.evWrongInputList,
2956 TosaErrorValidator.evWrongOutputList,
2957 TosaErrorValidator.evInputZeroPointNotZero,
2958 TosaErrorValidator.evOutputZeroPointNotZero,
2959 TosaErrorValidator.evPadLargerEqualKernel,
2960 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002961 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002962 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00002963 "data_gen": {
2964 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002966 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002967 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002968 "conv2d_TEMPLATE": {
2969 "op": Op.CONV2D,
2970 "operands": (1, 2),
2971 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002972 "build_fcn": (
2973 build_conv2d,
2974 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002975 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002976 TosaArgGen.agConv,
2977 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002978 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002979 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002980 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2981 "error_if_validators": (
2982 TosaErrorValidator.evWrongInputType,
2983 TosaErrorValidator.evWrongOutputType,
2984 TosaErrorValidator.evWrongInputList,
2985 TosaErrorValidator.evWrongOutputList,
2986 TosaErrorValidator.evInputZeroPointNotZero,
2987 TosaErrorValidator.evWeightZeroPointNotZero,
2988 TosaErrorValidator.evPadSmallerZero,
2989 TosaErrorValidator.evStrideSmallerOne,
2990 TosaErrorValidator.evDilationSmallerOne,
2991 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002992 TosaErrorValidator.evConvOutputShapeMismatch,
2993 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002994 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01002995 "data_gen": {
2996 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2997 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002998 "template": True,
2999 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003000 # Templated operator. Filled in by createDynamicOpLists
3001 "conv3d_TEMPLATE": {
3002 "op": Op.CONV3D,
3003 "operands": (1, 2),
3004 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003005 "build_fcn": (
3006 build_conv3d,
3007 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003008 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003009 TosaArgGen.agConv,
3010 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003011 "qgen": TosaQuantGen.qgConv,
3012 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003013 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3014 "error_if_validators": (
3015 TosaErrorValidator.evWrongInputType,
3016 TosaErrorValidator.evWrongOutputType,
3017 TosaErrorValidator.evWrongInputList,
3018 TosaErrorValidator.evWrongOutputList,
3019 TosaErrorValidator.evInputZeroPointNotZero,
3020 TosaErrorValidator.evWeightZeroPointNotZero,
3021 TosaErrorValidator.evPadSmallerZero,
3022 TosaErrorValidator.evStrideSmallerOne,
3023 TosaErrorValidator.evDilationSmallerOne,
3024 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003025 TosaErrorValidator.evConvOutputShapeMismatch,
3026 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003027 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003028 "template": True,
3029 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003030 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003031 "depthwise_conv2d_TEMPLATE": {
3032 "op": Op.DEPTHWISE_CONV2D,
3033 "operands": (1, 2),
3034 "filter": [1, 1],
3035 "rank": (4, 4),
3036 "build_fcn": (
3037 build_depthwise_conv2d,
3038 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003039 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003040 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003041 ),
3042 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003043 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003044 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3045 "error_if_validators": (
3046 TosaErrorValidator.evWrongInputType,
3047 TosaErrorValidator.evWrongOutputType,
3048 TosaErrorValidator.evWrongInputList,
3049 TosaErrorValidator.evWrongOutputList,
3050 TosaErrorValidator.evInputZeroPointNotZero,
3051 TosaErrorValidator.evWeightZeroPointNotZero,
3052 TosaErrorValidator.evPadSmallerZero,
3053 TosaErrorValidator.evStrideSmallerOne,
3054 TosaErrorValidator.evDilationSmallerOne,
3055 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003056 TosaErrorValidator.evConvOutputShapeMismatch,
3057 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003058 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003059 "template": True,
3060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 "fully_connected": {
3062 "op": Op.FULLY_CONNECTED,
3063 "operands": (1, 2),
3064 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003065 "build_fcn": (
3066 build_fully_connected,
3067 TosaTensorGen.tgFullyConnected,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003068 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003069 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003070 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003072 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003073 "error_if_validators": (
3074 TosaErrorValidator.evInputZeroPointNotZero,
3075 TosaErrorValidator.evWeightZeroPointNotZero,
3076 TosaErrorValidator.evWrongRank,
3077 TosaErrorValidator.evWrongInputType,
3078 TosaErrorValidator.evWrongOutputType,
3079 TosaErrorValidator.evWrongInputList,
3080 TosaErrorValidator.evWrongOutputList,
3081 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003082 "data_gen": {
3083 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3084 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 "matmul": {
3087 "op": Op.MATMUL,
3088 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003089 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 "build_fcn": (
3091 build_matmul,
3092 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003093 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003094 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "qgen": TosaQuantGen.qgMatmul,
3097 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evInputZeroPointNotZero,
3100 TosaErrorValidator.evWrongRank,
3101 TosaErrorValidator.evWrongInputType,
3102 TosaErrorValidator.evWrongOutputType,
3103 TosaErrorValidator.evWrongInputList,
3104 TosaErrorValidator.evWrongOutputList,
3105 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003106 "data_gen": {
3107 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003108 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003109 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003110 "max_pool2d": {
3111 "op": Op.MAX_POOL2D,
3112 "operands": (1, 0),
3113 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003114 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003115 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003116 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003117 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003118 TosaArgGen.agPooling,
3119 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003120 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003121 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003122 "error_if_validators": (
3123 TosaErrorValidator.evKernelSmallerOne,
3124 TosaErrorValidator.evStrideSmallerOne,
3125 TosaErrorValidator.evPadSmallerZero,
3126 TosaErrorValidator.evWrongRank,
3127 TosaErrorValidator.evWrongInputType,
3128 TosaErrorValidator.evWrongOutputType,
3129 TosaErrorValidator.evWrongInputList,
3130 TosaErrorValidator.evWrongOutputList,
3131 TosaErrorValidator.evPadLargerEqualKernel,
3132 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003133 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003134 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003135 "data_gen": {
3136 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3137 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003138 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003139 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003140 "transpose_conv2d_TEMPLATE": {
3141 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003142 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003143 "rank": (4, 4),
3144 "build_fcn": (
3145 build_transpose_conv2d,
3146 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003147 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003148 TosaArgGen.agTransposeConv2D,
3149 ),
3150 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003151 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003152 "invalid_test_validators": (
3153 TosaInvalidValidator.ivHeightWidthInvalid,
3154 TosaInvalidValidator.ivNonPositiveOutputShape,
3155 ),
3156 "error_if_validators": (
3157 TosaErrorValidator.evWrongInputType,
3158 TosaErrorValidator.evWrongOutputType,
3159 TosaErrorValidator.evWrongInputList,
3160 TosaErrorValidator.evWrongOutputList,
3161 TosaErrorValidator.evInputZeroPointNotZero,
3162 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003163 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003164 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003165 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003166 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003167 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003168 "template": True,
3169 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003170 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003171 "clamp": {
3172 "op": Op.CLAMP,
3173 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003174 "build_fcn": (
3175 build_clamp,
3176 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003177 TosaTensorValuesGen.tvgLazyGenDefault,
3178 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003180 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003181 "error_if_validators": (
3182 TosaErrorValidator.evMaxSmallerMin,
3183 TosaErrorValidator.evWrongInputType,
3184 TosaErrorValidator.evWrongOutputType,
3185 TosaErrorValidator.evWrongInputList,
3186 TosaErrorValidator.evWrongOutputList,
3187 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003188 "data_gen": {
3189 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3190 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003191 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003192 "sigmoid": {
3193 "op": Op.SIGMOID,
3194 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003195 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003196 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003197 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003198 TosaTensorValuesGen.tvgLazyGenDefault,
3199 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003200 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003201 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003202 "error_if_validators": (
3203 TosaErrorValidator.evWrongInputType,
3204 TosaErrorValidator.evWrongOutputType,
3205 TosaErrorValidator.evWrongInputList,
3206 TosaErrorValidator.evWrongOutputList,
3207 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003208 "data_gen": {
3209 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3210 },
3211 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003212 },
3213 "tanh": {
3214 "op": Op.TANH,
3215 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003216 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003217 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003219 TosaTensorValuesGen.tvgLazyGenDefault,
3220 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003221 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003222 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003223 "error_if_validators": (
3224 TosaErrorValidator.evWrongInputType,
3225 TosaErrorValidator.evWrongOutputType,
3226 TosaErrorValidator.evWrongInputList,
3227 TosaErrorValidator.evWrongOutputList,
3228 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003229 "data_gen": {
3230 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3231 },
3232 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003233 },
Won Jeon78155c62023-06-10 00:20:04 +00003234 "erf": {
3235 "op": Op.ERF,
3236 "operands": (1, 0),
3237 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003238 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003239 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003240 TosaTensorValuesGen.tvgLazyGenDefault,
3241 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003242 ),
3243 "types": TYPE_FP,
3244 "error_if_validators": (
3245 TosaErrorValidator.evWrongInputType,
3246 TosaErrorValidator.evWrongOutputType,
3247 TosaErrorValidator.evWrongInputList,
3248 TosaErrorValidator.evWrongOutputList,
3249 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003250 "data_gen": {
3251 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3252 },
3253 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003255 # Elementwise Binary Operators
3256 "add": {
3257 "op": Op.ADD,
3258 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003259 "build_fcn": (
3260 build_binary_broadcast,
3261 TosaTensorGen.tgBroadcastFuzz,
3262 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003263 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003266 "error_if_validators": (
3267 TosaErrorValidator.evRankMismatch,
3268 TosaErrorValidator.evWrongInputType,
3269 TosaErrorValidator.evWrongOutputType,
3270 TosaErrorValidator.evWrongInputList,
3271 TosaErrorValidator.evWrongOutputList,
3272 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003273 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003274 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003275 "data_gen": {
3276 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3277 },
3278 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 "arithmetic_right_shift": {
3281 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3282 "operands": (2, 0),
3283 "build_fcn": (
3284 build_arithmetic_right_shift,
3285 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 TosaArgGen.agArithmeticRightShift,
3288 ),
3289 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003290 "error_if_validators": (
3291 TosaErrorValidator.evRankMismatch,
3292 TosaErrorValidator.evWrongInputType,
3293 TosaErrorValidator.evWrongOutputType,
3294 TosaErrorValidator.evWrongInputList,
3295 TosaErrorValidator.evWrongOutputList,
3296 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003297 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "bitwise_and": {
3301 "op": Op.BITWISE_AND,
3302 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 "build_fcn": (
3304 build_binary_broadcast,
3305 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003306 TosaTensorValuesGen.tvgLazyGenDefault,
3307 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003310 "error_if_validators": (
3311 TosaErrorValidator.evRankMismatch,
3312 TosaErrorValidator.evWrongInputType,
3313 TosaErrorValidator.evWrongOutputType,
3314 TosaErrorValidator.evWrongInputList,
3315 TosaErrorValidator.evWrongOutputList,
3316 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003317 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 "bitwise_or": {
3321 "op": Op.BITWISE_OR,
3322 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003323 "build_fcn": (
3324 build_binary_broadcast,
3325 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003326 TosaTensorValuesGen.tvgLazyGenDefault,
3327 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 "error_if_validators": (
3331 TosaErrorValidator.evRankMismatch,
3332 TosaErrorValidator.evWrongInputType,
3333 TosaErrorValidator.evWrongOutputType,
3334 TosaErrorValidator.evWrongInputList,
3335 TosaErrorValidator.evWrongOutputList,
3336 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003337 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003338 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003339 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003340 "bitwise_xor": {
3341 "op": Op.BITWISE_XOR,
3342 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003343 "build_fcn": (
3344 build_binary_broadcast,
3345 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003346 TosaTensorValuesGen.tvgLazyGenDefault,
3347 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003350 "error_if_validators": (
3351 TosaErrorValidator.evRankMismatch,
3352 TosaErrorValidator.evWrongInputType,
3353 TosaErrorValidator.evWrongOutputType,
3354 TosaErrorValidator.evWrongInputList,
3355 TosaErrorValidator.evWrongOutputList,
3356 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003357 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003358 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003360 "intdiv": {
3361 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003362 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003363 "build_fcn": (
3364 build_binary_broadcast,
3365 TosaTensorGen.tgBroadcastFuzz,
3366 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003367 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003369 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003370 "error_if_validators": (
3371 TosaErrorValidator.evRankMismatch,
3372 TosaErrorValidator.evWrongInputType,
3373 TosaErrorValidator.evWrongOutputType,
3374 TosaErrorValidator.evWrongInputList,
3375 TosaErrorValidator.evWrongOutputList,
3376 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003377 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003379 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 "logical_and": {
3381 "op": Op.LOGICAL_AND,
3382 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003383 "build_fcn": (
3384 build_binary_broadcast,
3385 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003386 TosaTensorValuesGen.tvgLazyGenDefault,
3387 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003390 "error_if_validators": (
3391 TosaErrorValidator.evRankMismatch,
3392 TosaErrorValidator.evWrongInputType,
3393 TosaErrorValidator.evWrongOutputType,
3394 TosaErrorValidator.evWrongInputList,
3395 TosaErrorValidator.evWrongOutputList,
3396 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003397 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003398 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003399 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003400 "logical_left_shift": {
3401 "op": Op.LOGICAL_LEFT_SHIFT,
3402 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 "build_fcn": (
3404 build_binary_broadcast,
3405 TosaTensorGen.tgBroadcastFuzz,
3406 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003407 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003408 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003410 "error_if_validators": (
3411 TosaErrorValidator.evRankMismatch,
3412 TosaErrorValidator.evWrongInputType,
3413 TosaErrorValidator.evWrongOutputType,
3414 TosaErrorValidator.evWrongInputList,
3415 TosaErrorValidator.evWrongOutputList,
3416 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003417 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003418 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003420 "logical_right_shift": {
3421 "op": Op.LOGICAL_RIGHT_SHIFT,
3422 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003423 "build_fcn": (
3424 build_binary_broadcast,
3425 TosaTensorGen.tgBroadcastFuzz,
3426 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003427 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003429 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003430 "error_if_validators": (
3431 TosaErrorValidator.evRankMismatch,
3432 TosaErrorValidator.evWrongInputType,
3433 TosaErrorValidator.evWrongOutputType,
3434 TosaErrorValidator.evWrongInputList,
3435 TosaErrorValidator.evWrongOutputList,
3436 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003437 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003438 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003439 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 "logical_or": {
3441 "op": Op.LOGICAL_OR,
3442 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 "build_fcn": (
3444 build_binary_broadcast,
3445 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003446 TosaTensorValuesGen.tvgLazyGenDefault,
3447 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003448 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003450 "error_if_validators": (
3451 TosaErrorValidator.evRankMismatch,
3452 TosaErrorValidator.evWrongInputType,
3453 TosaErrorValidator.evWrongOutputType,
3454 TosaErrorValidator.evWrongInputList,
3455 TosaErrorValidator.evWrongOutputList,
3456 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003457 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003458 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003460 "logical_xor": {
3461 "op": Op.LOGICAL_XOR,
3462 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463 "build_fcn": (
3464 build_binary_broadcast,
3465 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003466 TosaTensorValuesGen.tvgLazyGenDefault,
3467 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003468 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003470 "error_if_validators": (
3471 TosaErrorValidator.evRankMismatch,
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongInputList,
3475 TosaErrorValidator.evWrongOutputList,
3476 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003477 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003478 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 "maximum": {
3481 "op": Op.MAXIMUM,
3482 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003483 "build_fcn": (
3484 build_binary_broadcast,
3485 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003486 TosaTensorValuesGen.tvgLazyGenDefault,
3487 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003488 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003489 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003490 "error_if_validators": (
3491 TosaErrorValidator.evRankMismatch,
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongInputList,
3495 TosaErrorValidator.evWrongOutputList,
3496 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003497 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003498 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003499 "data_gen": {
3500 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3501 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003502 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 "minimum": {
3504 "op": Op.MINIMUM,
3505 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003506 "build_fcn": (
3507 build_binary_broadcast,
3508 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003509 TosaTensorValuesGen.tvgLazyGenDefault,
3510 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003511 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003513 "error_if_validators": (
3514 TosaErrorValidator.evRankMismatch,
3515 TosaErrorValidator.evWrongInputType,
3516 TosaErrorValidator.evWrongOutputType,
3517 TosaErrorValidator.evWrongInputList,
3518 TosaErrorValidator.evWrongOutputList,
3519 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003520 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003521 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003522 "data_gen": {
3523 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3524 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "mul": {
3527 "op": Op.MUL,
3528 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_mul,
3531 TosaTensorGen.tgBroadcastFuzz,
3532 TosaTensorValuesGen.tvgMul,
3533 TosaArgGen.agMul,
3534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 "error_if_validators": (
3537 TosaErrorValidator.evWrongInputType,
3538 TosaErrorValidator.evWrongOutputType,
3539 TosaErrorValidator.evWrongInputList,
3540 TosaErrorValidator.evWrongOutputList,
3541 TosaErrorValidator.evRankMismatch,
3542 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003543 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003544 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003545 "data_gen": {
3546 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3547 },
3548 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003550 "pow": {
3551 "op": Op.POW,
3552 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
3554 build_binary_broadcast,
3555 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003556 TosaTensorValuesGen.tvgLazyGenDefault,
3557 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 "error_if_validators": (
3561 TosaErrorValidator.evRankMismatch,
3562 TosaErrorValidator.evWrongInputType,
3563 TosaErrorValidator.evWrongOutputType,
3564 TosaErrorValidator.evWrongInputList,
3565 TosaErrorValidator.evWrongOutputList,
3566 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003567 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003568 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003569 "data_gen": {
3570 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3571 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003572 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 "sub": {
3574 "op": Op.SUB,
3575 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003576 "build_fcn": (
3577 build_binary_broadcast,
3578 TosaTensorGen.tgBroadcastFuzz,
3579 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003580 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003581 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 "error_if_validators": (
3584 TosaErrorValidator.evRankMismatch,
3585 TosaErrorValidator.evWrongInputType,
3586 TosaErrorValidator.evWrongOutputType,
3587 TosaErrorValidator.evWrongInputList,
3588 TosaErrorValidator.evWrongOutputList,
3589 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003590 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003591 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003592 "data_gen": {
3593 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3594 },
3595 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003596 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 "table": {
3598 "op": Op.TABLE,
3599 # Use the automatic generation functions to create the input array
3600 # but create the table tensor in the build function, as it may be
3601 # a different type from the input
3602 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 "build_fcn": (
3604 build_table,
3605 TosaTensorGen.tgBasic,
3606 TosaTensorValuesGen.tvgDefault,
3607 TosaArgGen.agTable,
3608 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003609 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003610 "error_if_validators": (
3611 TosaErrorValidator.evWrongInputType,
3612 TosaErrorValidator.evWrongOutputType,
3613 TosaErrorValidator.evWrongInputList,
3614 TosaErrorValidator.evWrongOutputList,
3615 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003616 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 # Elementwise Unary operators
3618 "abs": {
3619 "op": Op.ABS,
3620 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621 "build_fcn": (
3622 build_unary,
3623 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003624 TosaTensorValuesGen.tvgLazyGenDefault,
3625 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003626 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003628 "error_if_validators": (
3629 TosaErrorValidator.evWrongInputType,
3630 TosaErrorValidator.evWrongOutputType,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003634 "data_gen": {
3635 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3636 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003637 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 "bitwise_not": {
3639 "op": Op.BITWISE_NOT,
3640 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_unary,
3643 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003644 TosaTensorValuesGen.tvgLazyGenDefault,
3645 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evWrongInputType,
3650 TosaErrorValidator.evWrongOutputType,
3651 TosaErrorValidator.evWrongInputList,
3652 TosaErrorValidator.evWrongOutputList,
3653 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003654 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 "ceil": {
3656 "op": Op.CEIL,
3657 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 "build_fcn": (
3659 build_unary,
3660 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003661 TosaTensorValuesGen.tvgLazyGenDefault,
3662 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003665 "error_if_validators": (
3666 TosaErrorValidator.evWrongInputType,
3667 TosaErrorValidator.evWrongOutputType,
3668 TosaErrorValidator.evWrongInputList,
3669 TosaErrorValidator.evWrongOutputList,
3670 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003671 "data_gen": {
3672 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3673 },
3674 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "clz": {
3677 "op": Op.CLZ,
3678 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_unary,
3681 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003682 TosaTensorValuesGen.tvgLazyGenDefault,
3683 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003686 "error_if_validators": (
3687 TosaErrorValidator.evWrongInputType,
3688 TosaErrorValidator.evWrongOutputType,
3689 TosaErrorValidator.evWrongInputList,
3690 TosaErrorValidator.evWrongOutputList,
3691 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 "exp": {
3694 "op": Op.EXP,
3695 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003696 "build_fcn": (
3697 build_unary,
3698 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003699 TosaTensorValuesGen.tvgLazyGenDefault,
3700 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003701 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003702 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003703 "error_if_validators": (
3704 TosaErrorValidator.evWrongInputType,
3705 TosaErrorValidator.evWrongOutputType,
3706 TosaErrorValidator.evWrongInputList,
3707 TosaErrorValidator.evWrongOutputList,
3708 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003709 "data_gen": {
3710 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3711 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003712 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 "floor": {
3714 "op": Op.FLOOR,
3715 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 "build_fcn": (
3717 build_unary,
3718 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003719 TosaTensorValuesGen.tvgLazyGenDefault,
3720 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003723 "error_if_validators": (
3724 TosaErrorValidator.evWrongInputType,
3725 TosaErrorValidator.evWrongOutputType,
3726 TosaErrorValidator.evWrongInputList,
3727 TosaErrorValidator.evWrongOutputList,
3728 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003729 "data_gen": {
3730 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3731 },
3732 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003734 "log": {
3735 "op": Op.LOG,
3736 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 "build_fcn": (
3738 build_unary,
3739 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003740 TosaTensorValuesGen.tvgLazyGenDefault,
3741 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003744 "error_if_validators": (
3745 TosaErrorValidator.evWrongInputType,
3746 TosaErrorValidator.evWrongOutputType,
3747 TosaErrorValidator.evWrongInputList,
3748 TosaErrorValidator.evWrongOutputList,
3749 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003750 "data_gen": {
3751 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3752 },
3753 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003754 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 "logical_not": {
3756 "op": Op.LOGICAL_NOT,
3757 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 "build_fcn": (
3759 build_unary,
3760 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003761 TosaTensorValuesGen.tvgLazyGenDefault,
3762 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003763 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003765 "error_if_validators": (
3766 TosaErrorValidator.evWrongInputType,
3767 TosaErrorValidator.evWrongOutputType,
3768 TosaErrorValidator.evWrongInputList,
3769 TosaErrorValidator.evWrongOutputList,
3770 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 "negate": {
3773 "op": Op.NEGATE,
3774 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 "build_fcn": (
3776 build_unary,
3777 TosaTensorGen.tgBasic,
3778 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003779 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003780 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "qgen": TosaQuantGen.qgUnary,
3782 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003783 "error_if_validators": (
3784 TosaErrorValidator.evInputZeroPointNotZero,
3785 TosaErrorValidator.evOutputZeroPointNotZero,
3786 TosaErrorValidator.evWrongInputType,
3787 TosaErrorValidator.evWrongOutputType,
3788 TosaErrorValidator.evWrongInputList,
3789 TosaErrorValidator.evWrongOutputList,
3790 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003791 "data_gen": {
3792 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3793 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003794 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "reciprocal": {
3796 "op": Op.RECIPROCAL,
3797 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003798 "build_fcn": (
3799 build_unary,
3800 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003801 TosaTensorValuesGen.tvgLazyGenDefault,
3802 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003803 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003805 "error_if_validators": (
3806 TosaErrorValidator.evWrongInputType,
3807 TosaErrorValidator.evWrongOutputType,
3808 TosaErrorValidator.evWrongInputList,
3809 TosaErrorValidator.evWrongOutputList,
3810 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003811 "data_gen": {
3812 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3813 },
3814 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003816 "rsqrt": {
3817 "op": Op.RSQRT,
3818 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003819 "build_fcn": (
3820 build_unary,
3821 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003822 TosaTensorValuesGen.tvgLazyGenDefault,
3823 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003826 "error_if_validators": (
3827 TosaErrorValidator.evWrongInputType,
3828 TosaErrorValidator.evWrongOutputType,
3829 TosaErrorValidator.evWrongInputList,
3830 TosaErrorValidator.evWrongOutputList,
3831 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003832 "data_gen": {
3833 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3834 },
3835 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003836 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 # Elementwise Ternary operators
3838 "select": {
3839 "op": Op.SELECT,
3840 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 "build_fcn": (
3842 build_select,
3843 TosaTensorGen.tgBroadcastFuzz,
3844 TosaTensorValuesGen.tvgSelect,
3845 None,
3846 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003847 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 "error_if_validators": (
3849 TosaErrorValidator.evRankMismatch,
3850 TosaErrorValidator.evWrongInputType,
3851 TosaErrorValidator.evWrongOutputType,
3852 TosaErrorValidator.evWrongInputList,
3853 TosaErrorValidator.evWrongOutputList,
3854 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003855 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003856 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 # Comparison operators
3859 "equal": {
3860 "op": Op.EQUAL,
3861 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003862 "build_fcn": (
3863 build_comparison,
3864 TosaTensorGen.tgBroadcastFuzz,
3865 TosaTensorValuesGen.tvgEqual,
3866 None,
3867 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003869 "error_if_validators": (
3870 TosaErrorValidator.evRankMismatch,
3871 TosaErrorValidator.evWrongInputType,
3872 TosaErrorValidator.evWrongOutputType,
3873 TosaErrorValidator.evWrongInputList,
3874 TosaErrorValidator.evWrongOutputList,
3875 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003876 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003877 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 "greater_equal": {
3880 "op": Op.GREATER_EQUAL,
3881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 "build_fcn": (
3883 build_comparison,
3884 TosaTensorGen.tgBroadcastFuzz,
3885 TosaTensorValuesGen.tvgDefault,
3886 None,
3887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 "error_if_validators": (
3890 TosaErrorValidator.evRankMismatch,
3891 TosaErrorValidator.evWrongInputType,
3892 TosaErrorValidator.evWrongOutputType,
3893 TosaErrorValidator.evWrongInputList,
3894 TosaErrorValidator.evWrongOutputList,
3895 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003896 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 "greater": {
3900 "op": Op.GREATER,
3901 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 "build_fcn": (
3903 build_comparison,
3904 TosaTensorGen.tgBroadcastFuzz,
3905 TosaTensorValuesGen.tvgDefault,
3906 None,
3907 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 "error_if_validators": (
3910 TosaErrorValidator.evRankMismatch,
3911 TosaErrorValidator.evWrongInputType,
3912 TosaErrorValidator.evWrongOutputType,
3913 TosaErrorValidator.evWrongInputList,
3914 TosaErrorValidator.evWrongOutputList,
3915 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003916 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003918 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003919 # Reduction operators
3920 "reduce_all": {
3921 "op": Op.REDUCE_ALL,
3922 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003923 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_reduce,
3926 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003928 TosaArgGen.agAxis,
3929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evAxisLargerRank,
3933 TosaErrorValidator.evAxisSmallerZero,
3934 TosaErrorValidator.evShapeOfAxisNotOne,
3935 TosaErrorValidator.evWrongInputType,
3936 TosaErrorValidator.evWrongOutputType,
3937 TosaErrorValidator.evWrongRank,
3938 TosaErrorValidator.evWrongInputList,
3939 TosaErrorValidator.evWrongOutputList,
3940 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003942 "reduce_any": {
3943 "op": Op.REDUCE_ANY,
3944 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003945 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003946 "build_fcn": (
3947 build_reduce,
3948 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003949 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 TosaArgGen.agAxis,
3951 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003952 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003953 "error_if_validators": (
3954 TosaErrorValidator.evAxisLargerRank,
3955 TosaErrorValidator.evAxisSmallerZero,
3956 TosaErrorValidator.evShapeOfAxisNotOne,
3957 TosaErrorValidator.evWrongInputType,
3958 TosaErrorValidator.evWrongOutputType,
3959 TosaErrorValidator.evWrongRank,
3960 TosaErrorValidator.evWrongInputList,
3961 TosaErrorValidator.evWrongOutputList,
3962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 "reduce_max": {
3965 "op": Op.REDUCE_MAX,
3966 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003967 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003968 "build_fcn": (
3969 build_reduce,
3970 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003971 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 TosaArgGen.agAxis,
3973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003974 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003975 "error_if_validators": (
3976 TosaErrorValidator.evAxisLargerRank,
3977 TosaErrorValidator.evAxisSmallerZero,
3978 TosaErrorValidator.evShapeOfAxisNotOne,
3979 TosaErrorValidator.evWrongInputType,
3980 TosaErrorValidator.evWrongOutputType,
3981 TosaErrorValidator.evWrongRank,
3982 TosaErrorValidator.evWrongInputList,
3983 TosaErrorValidator.evWrongOutputList,
3984 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003985 "data_gen": {
3986 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3987 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003988 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003990 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003992 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 "build_fcn": (
3994 build_reduce,
3995 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003996 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003997 TosaArgGen.agAxis,
3998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004000 "error_if_validators": (
4001 TosaErrorValidator.evAxisLargerRank,
4002 TosaErrorValidator.evAxisSmallerZero,
4003 TosaErrorValidator.evShapeOfAxisNotOne,
4004 TosaErrorValidator.evWrongInputType,
4005 TosaErrorValidator.evWrongOutputType,
4006 TosaErrorValidator.evWrongRank,
4007 TosaErrorValidator.evWrongInputList,
4008 TosaErrorValidator.evWrongOutputList,
4009 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004010 "data_gen": {
4011 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 "reduce_product": {
4015 "op": Op.REDUCE_PRODUCT,
4016 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004017 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004018 "build_fcn": (
4019 build_reduce,
4020 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004021 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004022 TosaArgGen.agAxis,
4023 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004025 "error_if_validators": (
4026 TosaErrorValidator.evAxisLargerRank,
4027 TosaErrorValidator.evAxisSmallerZero,
4028 TosaErrorValidator.evShapeOfAxisNotOne,
4029 TosaErrorValidator.evWrongInputType,
4030 TosaErrorValidator.evWrongOutputType,
4031 TosaErrorValidator.evWrongRank,
4032 TosaErrorValidator.evWrongInputList,
4033 TosaErrorValidator.evWrongOutputList,
4034 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004036 "reduce_sum": {
4037 "op": Op.REDUCE_SUM,
4038 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004039 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004040 "build_fcn": (
4041 build_reduce,
4042 TosaTensorGen.tgBasic,
4043 TosaTensorValuesGen.tvgReduceSum,
4044 TosaArgGen.agAxis,
4045 ),
James Ward24dbc422022-10-19 12:20:31 +01004046 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004047 "error_if_validators": (
4048 TosaErrorValidator.evAxisLargerRank,
4049 TosaErrorValidator.evAxisSmallerZero,
4050 TosaErrorValidator.evShapeOfAxisNotOne,
4051 TosaErrorValidator.evWrongInputType,
4052 TosaErrorValidator.evWrongOutputType,
4053 TosaErrorValidator.evWrongRank,
4054 TosaErrorValidator.evWrongInputList,
4055 TosaErrorValidator.evWrongOutputList,
4056 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004057 "data_gen": {
4058 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004061 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004062 "concat": {
4063 "op": Op.CONCAT,
4064 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004065 "build_fcn": (
4066 build_concat,
4067 TosaTensorGen.tgConcat,
4068 TosaTensorValuesGen.tvgConcat,
4069 TosaArgGen.agAxis,
4070 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004071 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004072 "error_if_validators": (
4073 TosaErrorValidator.evAxisLargerRank,
4074 TosaErrorValidator.evAxisSmallerZero,
4075 TosaErrorValidator.evConcatInputRankMismatch,
4076 TosaErrorValidator.evConcatShapeSumMismatch,
4077 TosaErrorValidator.evConcatInputDimMismatch,
4078 TosaErrorValidator.evWrongInputType,
4079 TosaErrorValidator.evWrongOutputType,
4080 TosaErrorValidator.evWrongOutputList,
4081 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004082 },
4083 "pad": {
4084 "op": Op.PAD,
4085 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_pad,
4088 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004089 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 TosaArgGen.agPad,
4091 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004092 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evWrongInputType,
4095 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004096 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004097 TosaErrorValidator.evWrongOutputType,
4098 TosaErrorValidator.evWrongInputList,
4099 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004100 TosaErrorValidator.evRankMismatch,
4101 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004102 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004103 "data_gen": {
4104 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4105 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004106 },
Won Jeona21b2e82023-08-10 10:33:01 +00004107 "dim": {
4108 "op": Op.DIM,
4109 "operands": (1, 0),
4110 "build_fcn": (
4111 build_dim,
4112 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004113 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004114 TosaArgGen.agAxis,
4115 ),
4116 "types": TYPE_FIB,
4117 "error_if_validators": (
4118 TosaErrorValidator.evAxisLargerRank,
4119 TosaErrorValidator.evAxisSmallerZero,
4120 TosaErrorValidator.evWrongInputType,
4121 TosaErrorValidator.evWrongInputList,
4122 TosaErrorValidator.evWrongOutputList,
4123 TosaErrorValidator.evWrongRank,
4124 ),
4125 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004126 "reshape": {
4127 "op": Op.RESHAPE,
4128 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004129 "build_fcn": (
4130 build_reshape,
4131 TosaTensorGen.tgBasic,
4132 TosaTensorValuesGen.tvgDefault,
4133 TosaArgGen.agReshape,
4134 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004135 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004136 "error_if_validators": (
4137 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4138 TosaErrorValidator.evWrongInputType,
4139 TosaErrorValidator.evWrongOutputType,
4140 TosaErrorValidator.evWrongInputList,
4141 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004142 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4143 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004144 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004145 },
4146 "reverse": {
4147 "op": Op.REVERSE,
4148 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 "build_fcn": (
4150 build_reverse,
4151 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004152 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004153 TosaArgGen.agAxis,
4154 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004155 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004156 "error_if_validators": (
4157 TosaErrorValidator.evAxisSmallerZero,
4158 TosaErrorValidator.evAxisLargerRank,
4159 TosaErrorValidator.evWrongInputType,
4160 TosaErrorValidator.evWrongOutputType,
4161 TosaErrorValidator.evWrongInputList,
4162 TosaErrorValidator.evWrongOutputList,
4163 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004164 },
4165 "slice": {
4166 "op": Op.SLICE,
4167 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004168 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004169 "build_fcn": (
4170 build_slice,
4171 TosaTensorGen.tgBasic,
4172 TosaTensorValuesGen.tvgDefault,
4173 TosaArgGen.agSlice,
4174 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004175 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004176 "error_if_validators": (
4177 TosaErrorValidator.evStartSmallerZero,
4178 TosaErrorValidator.evSizeSmallerEqualZero,
4179 TosaErrorValidator.evStartSizeOutsideBounds,
4180 TosaErrorValidator.evSizeOutputShapeMismatch,
4181 TosaErrorValidator.evInputSizeStartLengthMismatch,
4182 TosaErrorValidator.evWrongRank,
4183 TosaErrorValidator.evWrongInputType,
4184 TosaErrorValidator.evWrongOutputType,
4185 TosaErrorValidator.evWrongInputList,
4186 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004187 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004189 },
4190 "tile": {
4191 "op": Op.TILE,
4192 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004193 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_tile,
4196 TosaTensorGen.tgBasic,
4197 TosaTensorValuesGen.tvgDefault,
4198 TosaArgGen.agTile,
4199 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004200 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evWrongInputType,
4203 TosaErrorValidator.evWrongOutputType,
4204 TosaErrorValidator.evWrongInputList,
4205 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004206 TosaErrorValidator.evRankMismatch,
4207 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004209 },
4210 "transpose": {
4211 "op": Op.TRANSPOSE,
4212 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004213 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004214 "build_fcn": (
4215 build_transpose,
4216 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004217 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004218 TosaArgGen.agTranspose,
4219 ),
4220 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004221 "error_if_validators": (
4222 TosaErrorValidator.evIndexOutsideBounds,
4223 TosaErrorValidator.evIndexUsedTwice,
4224 TosaErrorValidator.evWrongInputType,
4225 TosaErrorValidator.evWrongOutputType,
4226 TosaErrorValidator.evWrongInputList,
4227 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004228 TosaErrorValidator.evWrongRank,
4229 TosaErrorValidator.evRankMismatch,
4230 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004231 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004232 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 # Data nodes
4234 "const": {
4235 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004236 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_const,
4239 TosaTensorGen.tgBasic,
4240 TosaTensorValuesGen.tvgDefault,
4241 None,
4242 ),
Luke Hutton65872422023-02-20 10:33:04 +00004243 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004244 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004245 "identity": {
4246 "op": Op.IDENTITY,
4247 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004248 "build_fcn": (
4249 build_unary,
4250 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004251 TosaTensorValuesGen.tvgLazyGenDefault,
4252 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004253 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004255 "data_gen": {
4256 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4257 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004258 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004259 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 "gather": {
4261 "op": Op.GATHER,
4262 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4263 "operands": (1, 0),
4264 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 "build_fcn": (
4266 build_gather,
4267 TosaTensorGen.tgBasic,
4268 TosaTensorValuesGen.tvgDefault,
4269 None,
4270 ),
James Ward24dbc422022-10-19 12:20:31 +01004271 "types": (
4272 DType.INT8,
4273 DType.INT16,
4274 DType.INT32,
4275 DType.FP16,
4276 DType.BF16,
4277 DType.FP32,
4278 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 "error_if_validators": (
4280 TosaErrorValidator.evWrongInputType,
4281 TosaErrorValidator.evWrongOutputType,
4282 TosaErrorValidator.evWrongInputList,
4283 TosaErrorValidator.evWrongOutputList,
4284 TosaErrorValidator.evWrongRank,
4285 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004286 },
4287 "scatter": {
4288 "op": Op.SCATTER,
4289 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004290 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004291 "operands": (2, 0),
4292 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004293 "build_fcn": (
4294 build_scatter,
4295 TosaTensorGen.tgScatter,
4296 TosaTensorValuesGen.tvgDefault,
4297 None,
4298 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004299 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004300 "error_if_validators": (
4301 TosaErrorValidator.evWrongInputType,
4302 TosaErrorValidator.evWrongOutputType,
4303 TosaErrorValidator.evWrongInputList,
4304 TosaErrorValidator.evWrongOutputList,
4305 TosaErrorValidator.evWrongRank,
4306 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004307 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004308 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004309 "resize": {
4310 "op": Op.RESIZE,
4311 "operands": (1, 0),
4312 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004313 "build_fcn": (
4314 build_resize,
4315 TosaTensorGen.tgNHWC,
4316 TosaTensorValuesGen.tvgDefault,
4317 TosaArgGen.agResize,
4318 ),
James Ward24dbc422022-10-19 12:20:31 +01004319 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004320 "invalid_test_validators": (
4321 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004322 ),
4323 "error_if_validators": (
4324 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004325 TosaErrorValidator.evScaleSmallerEqualZero,
4326 TosaErrorValidator.evScaleNLargerMax,
4327 TosaErrorValidator.evScaleDLargerMax,
4328 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004329 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004330 TosaErrorValidator.evBorderSmallerMin,
4331 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004332 TosaErrorValidator.evWrongInputType,
4333 TosaErrorValidator.evWrongOutputType,
4334 TosaErrorValidator.evWrongRank,
4335 TosaErrorValidator.evWrongInputList,
4336 TosaErrorValidator.evWrongOutputList,
4337 TosaErrorValidator.evBatchMismatch,
4338 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004339 TosaErrorValidator.evResizeOutputShapeMismatch,
4340 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004341 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004343 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004344 "cast": {
4345 "op": Op.CAST,
4346 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 "build_fcn": (
4348 build_cast,
4349 TosaTensorGen.tgBasic,
4350 TosaTensorValuesGen.tvgDefault,
4351 TosaArgGen.agCast,
4352 ),
James Ward8b390432022-08-12 20:48:56 +01004353 "types": (
4354 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004355 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004356 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004357 DType.INT8,
4358 DType.INT16,
4359 DType.INT32,
4360 DType.BOOL,
4361 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004362 "error_if_validators": (
4363 TosaErrorValidator.evWrongInputType,
4364 TosaErrorValidator.evWrongOutputType,
4365 TosaErrorValidator.evWrongInputList,
4366 TosaErrorValidator.evWrongOutputList,
4367 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004368 },
4369 "rescale": {
4370 "op": Op.RESCALE,
4371 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004372 "build_fcn": (
4373 build_rescale,
4374 TosaTensorGen.tgBasic,
4375 TosaTensorValuesGen.tvgDefault,
4376 TosaArgGen.agRescale,
4377 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004378 "types": [
4379 DType.UINT8,
4380 DType.INT8,
4381 DType.INT16,
4382 DType.INT32,
4383 DType.INT48,
4384 DType.UINT16,
4385 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004386 "error_if_validators": (
4387 TosaErrorValidator.evInputZeroPointNotZero,
4388 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004389 TosaErrorValidator.evU16InputZeroPointNotValid,
4390 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004391 TosaErrorValidator.evScaleTrue,
4392 TosaErrorValidator.evScaleNotTrue,
4393 TosaErrorValidator.evWrongInputType,
4394 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004395 TosaErrorValidator.evWrongInputList,
4396 TosaErrorValidator.evWrongOutputList,
4397 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004398 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004399 # Custom
4400 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004401 # Control flow operators
# Two variants of cond_if: one that generates one of two constant tensors (no
# inputs to the basic blocks, one output) and another that either adds or
# subtracts two tensors (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004405 "cond_if_const": {
4406 "op": Op.COND_IF,
4407 "operands": (0, 2),
4408 "build_fcn": (
4409 build_cond_if_const,
4410 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004411 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004412 TosaArgGen.agCondIf,
4413 ),
4414 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004415 "error_if_validators": (
4416 TosaErrorValidator.evOutputListThenGraphMismatch,
4417 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004418 TosaErrorValidator.evCondIfCondNotMatchingBool,
4419 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004420 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004421 },
4422 "cond_if_binary": {
4423 "op": Op.COND_IF,
4424 "operands": (2, 0),
4425 "build_fcn": (
4426 build_cond_if_binary,
4427 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004428 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004429 TosaArgGen.agCondIf,
4430 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004431 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004432 "error_if_validators": (
4433 TosaErrorValidator.evInputListThenGraphMismatch,
4434 TosaErrorValidator.evInputListElseGraphMismatch,
4435 TosaErrorValidator.evOutputListThenGraphMismatch,
4436 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004437 TosaErrorValidator.evCondIfCondNotMatchingBool,
4438 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004439 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004441 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004442 "while_loop": {
4443 "op": Op.WHILE_LOOP,
4444 "operands": (0, 1),
4445 "build_fcn": (
4446 build_while_loop,
4447 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004448 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004449 TosaArgGen.agWhileLoop,
4450 ),
4451 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004452 "error_if_validators": (
4453 TosaErrorValidator.evInputListOutputListMismatch,
4454 TosaErrorValidator.evInputListCondGraphMismatch,
4455 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4456 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4457 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004458 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004459 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 },
Luke Hutton57287132023-02-06 14:54:18 +00004461 "fft2d": {
4462 "op": Op.FFT2D,
4463 "operands": (2, 0),
4464 "rank": (3, 3),
4465 "build_fcn": (
4466 build_fft2d,
4467 TosaTensorGen.tgFFT2d,
4468 TosaTensorValuesGen.tvgDefault,
4469 TosaArgGen.agFFT2d,
4470 ),
4471 "types": [DType.FP32],
4472 "error_if_validators": (
4473 TosaErrorValidator.evWrongInputType,
4474 TosaErrorValidator.evWrongOutputType,
4475 TosaErrorValidator.evWrongInputList,
4476 TosaErrorValidator.evWrongOutputList,
4477 TosaErrorValidator.evWrongRank,
4478 TosaErrorValidator.evBatchMismatch,
4479 TosaErrorValidator.evKernelNotPowerOfTwo,
4480 TosaErrorValidator.evFFTInputShapeMismatch,
4481 TosaErrorValidator.evFFTOutputShapeMismatch,
4482 ),
4483 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004484 "rfft2d": {
4485 "op": Op.RFFT2D,
4486 "operands": (1, 0),
4487 "rank": (3, 3),
4488 "build_fcn": (
4489 build_rfft2d,
4490 TosaTensorGen.tgRFFT2d,
4491 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004492 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004493 ),
4494 "types": [DType.FP32],
4495 "error_if_validators": (
4496 TosaErrorValidator.evWrongInputType,
4497 TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongInputList,
4499 TosaErrorValidator.evWrongOutputList,
4500 TosaErrorValidator.evWrongRank,
4501 TosaErrorValidator.evBatchMismatch,
4502 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004503 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004504 ),
4505 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004506 }
4507
Kevin Cheng550ccc52021-03-03 11:21:43 -08004508
Eric Kunzee5e26762020-10-13 16:11:07 -07004509class OutputShaper:
4510 # Methods in this class compute the expected output shape and datatype
4511 # for common classes of operations
4512 def __init__(self):
4513 pass
4514
# These methods compute the output shape and dtype, pass them to
# ser.addOutput() to create the new output tensor, and return its result
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise binary op with broadcast.

    When no ERROR_IF case is requested, each output dimension takes ``b``'s
    size wherever ``a`` has size 1, and ``a``'s size otherwise.  With an
    ``error_name`` set, ``a``'s shape is used as-is before any fuzzing.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator (``integers`` and ``choice`` are used).
        a: First input tensor; its ``shape`` and ``dtype`` are read.
        b: Second input tensor; its ``shape`` and ``dtype`` are read.
        error_name: Optional ``ErrorIf`` member.  ``DimensionMismatch``
            grows one output dimension by 1, ``WrongOutputType`` picks a
            dtype different from ``a.dtype`` and ``RankMismatch`` skips the
            equal-rank assertion.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn even when it is not used, so RNG consumption
    # does not depend on error_name (changing that would alter the random
    # sequence seen by later draws).  A rank-0 (scalar) shape has no
    # dimension to fuzz, and rng.integers(0, 0) raises ValueError on the
    # empty range, so the draw is skipped only in that case.
    fuzz_idx = rng.integers(0, len(a.shape)) if len(a.shape) > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004550
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output for an element-wise binary op that never broadcasts.

    Both inputs must agree in rank, dtype and every dimension.  The output
    is a fresh copy of ``a``'s shape with ``a``'s dtype.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` builds the tensor.
        a: First input tensor; ``shape`` and ``dtype`` are read.
        b: Second input tensor; must match ``a`` exactly.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = list(a.shape)
    for idx, dim in enumerate(out_shape):
        # No broadcasting allowed: every dimension must already match.
        assert dim == b.shape[idx]

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004562
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output always has exactly ``a``'s shape.  Its dtype is ``a.dtype``,
    except for a ``WrongOutputType`` ERROR_IF test, where a different dtype
    is chosen at random.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` builds the tensor.
        rng: Random generator; only ``choice`` is used.
        a: Input tensor; ``shape`` and ``dtype`` are read.
        error_name: Optional ``ErrorIf`` member for the test being built.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        bad_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
        return ser.addOutput(a.shape, bad_dtype)

    return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004581
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for a select op over ``cond``, ``a`` and ``b``.

    The output shape follows ``cond``.  Where ``cond`` has size 1 (and no
    ERROR_IF case is requested), the largest of the three sizes is used.
    Otherwise ``cond``'s size is kept.  The dtype is ``a.dtype`` unless an
    incorrect output type is requested.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator (``integers`` and ``choice`` are used).
        cond: Condition tensor; its ``shape`` is read.
        a: First value tensor; its ``shape`` and ``dtype`` are read.
        b: Second value tensor; its ``shape`` and ``dtype`` are read.
        error_name: Optional ``ErrorIf`` member.  ``DimensionMismatch``
            grows one output dimension by 1, ``WrongOutputType`` picks a
            dtype different from ``a.dtype`` and ``RankMismatch`` skips the
            equal-rank assertion.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # The fuzz index is drawn even when it is not used, so RNG consumption
    # does not depend on error_name (changing that would alter the random
    # sequence seen by later draws).  A rank-0 (scalar) shape has no
    # dimension to fuzz, and rng.integers(0, 0) raises ValueError on the
    # empty range, so the draw is skipped only in that case.
    fuzz_idx = rng.integers(0, len(a.shape)) if len(a.shape) > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004615
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the boolean output tensor for an element-wise comparison op.

    Each output dimension takes ``b``'s size where ``a`` has size 1 and
    ``b`` has that dimension, and ``a``'s size otherwise.  The output dtype
    is ``DType.BOOL`` unless an incorrect output type is requested.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator (``integers`` and ``choice`` are used).
        a: First input tensor; its ``shape`` and ``dtype`` are read.
        b: Second input tensor; its ``shape`` and ``dtype`` are read.
        error_name: Optional ``ErrorIf`` member.  ``DimensionMismatch``
            grows one output dimension by 1, ``WrongOutputType`` picks a
            non-BOOL dtype and ``RankMismatch`` skips the equal-rank
            assertion.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast; the length check tolerates a lower-rank ``b`` when the
    # rank assertion above was skipped for RankMismatch.
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn even when it is not used, so RNG consumption
    # does not depend on error_name (changing that would alter the random
    # sequence seen by later draws).  A rank-0 (scalar) shape has no
    # dimension to fuzz, and rng.integers(0, 0) raises ValueError on the
    # empty range, so the draw is skipped only in that case.
    fuzz_idx = rng.integers(0, len(a.shape)) if len(a.shape) > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is the expected output type, so every entry here is wrong.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004649
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction over ``axis``.

    Normally the output is ``a``'s shape with ``axis`` collapsed to 1.  For
    the axis-related ERROR_IF cases the axis is left uncollapsed.  For
    ``ShapeOfAxisNotOne``, a size-1 axis is replaced by a random size drawn
    with ``rng.integers(2, 10)``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` builds the tensor.
        rng: Random generator (``integers`` and ``choice`` are used).
        a: Input tensor; ``a.shape.copy()`` is taken, so ``shape`` is
            presumably a list.
        axis: Index of the reduced dimension.
        error_name: Optional ``ErrorIf`` member for the test being built.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        out_shape[axis] = 1

    # Make sure the reduced axis is *not* 1 for this error case.
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004678
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the INT32 output tensor for an argmax over ``axis``.

    The output is ``a``'s shape with ``axis`` removed.  The ERROR_IF cases
    handled here perturb that result:

    * ``AxisSmallerZero`` / ``AxisLargerRank``: the axis is not removed.
    * ``ArgmaxOutputRankMismatch``: the rank is changed by dropping the
      leading dimension or appending a size-1 dimension.
    * ``ArgmaxOutputShapeMismatch``: every dimension is increased.
    * ``WrongOutputType``: a dtype other than INT32 is chosen.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` builds the tensor.
        rng: Random generator (``integers`` and ``choice`` are used).
        a: Input tensor; ``a.shape.copy()`` is taken, so ``shape`` is
            presumably a list.
        axis: Index of the dimension being reduced.
        error_name: Optional ``ErrorIf`` member for the test being built.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        # Only shrink when at least one dimension would remain.
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx in range(len(out_shape)):
            out_shape[idx] += rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))
    return ser.addOutput(out_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004712
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a 2D convolution.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC.  Each
    spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` builds the tensor.
        rng: Random generator; only ``choice`` is used.
        ifm: Input feature map; ``shape`` and ``dtype`` are read.
        filter: Weight tensor; ``shape`` is read.
        accum_dtype: Output dtype used in the non-error case.
        strides: ``[stride_y, stride_x]``.
        padding: ``[top, bottom, left, right]`` (indexed 0..3 below).
        dilations: ``[dilation_y, dilation_x]``.
        error_name: Optional ``ErrorIf`` member for the test being built.

    Returns:
        The value returned by ``ser.addOutput``.
    """

    def _out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Spatial output size of a dilated, padded, strided convolution.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> grow height, 2 -> grow width, 3 -> grow both.  Growth is a
        # whole multiple of the stride so the non-integer error is avoided.
        change = rng.choice(choices)
        if change != 2:
            h += rng.choice(choices) * strides[0]
        if change != 1:
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a bad input type there is no meaningful accumulator, so fall back
    # to a potentially valid integer output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004764
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a CONV3D operator to the serializer.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.  Each spatial
    extent (D, H, W) is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ERROR_IF handling:
      * ConvOutputShapeMismatch - one spatial extent (or all three) is grown
        by a random multiple of that axis' stride.
      * WrongInputType - INT32 is used as a potentially correct output dtype.
      * WrongOutputType - a random usable dtype other than the expected one.

    Returns whatever ``ser.addOutput`` returns for the new tensor.
    """
    # Axis k (0=D, 1=H, 2=W) reads ifm dim k+1, padding[2k], padding[2k+1],
    # filter dim k+1, dilations[k] and strides[k].
    spatial = []
    for axis in range(3):
        extent = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3, 4]
        # 1/2/3 perturb D/H/W alone, 4 perturbs all three
        selected = rng.choice(options)
        for axis in range(3):
            if selected in (axis + 1, 4):
                # step in multiples of stride to avoid the non-integer error case
                spatial[axis] += rng.choice(options) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With a bad input type, use some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are acceptable accumulator types)
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
4826
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a DEPTHWISE_CONV2D operator to the serializer.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).  Each spatial
    extent (H, W) is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ERROR_IF handling:
      * ConvOutputShapeMismatch - H, W or both are grown by a random multiple
        of that axis' stride.
      * WrongInputType - INT32 is used as a potentially correct output dtype.
      * WrongOutputType - a random usable dtype other than the expected one.

    Returns whatever ``ser.addOutput`` returns for the new tensor.
    """
    # Axis k (0=H, 1=W) reads ifm dim k+1, padding[2k], padding[2k+1],
    # filter dim k (filter is HWCM), dilations[k] and strides[k].
    spatial = []
    for axis in range(2):
        extent = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        # 1 -> H only, 2 -> W only, 3 -> both
        selected = rng.choice(options)
        for axis in range(2):
            if selected in (axis + 1, 3):
                # step in multiples of stride to avoid the non-integer error case
                spatial[axis] += rng.choice(options) * strides[axis]

    # Output channels are C * M
    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    # With a bad input type, use some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are acceptable accumulator types)
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004877
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator to the serializer.

    The input is NHWC.  The output keeps N and C and gets
    ``(in + pad_before + pad_after - kernel) // stride + 1`` for H and W.
    If either stride is non-positive or any padding value is negative, H and
    W are set to 1 because the test is invalid anyway.

    ERROR_IF handling:
      * PoolingOutputShapeMismatch - H, W or both grow by a random multiple
        of the matching stride.
      * WrongOutputType - a random dtype other than the input dtype.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        out_h, out_w = 1, 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        options = [1, 2, 3]
        selected = rng.choice(options)
        # step in multiples of stride to avoid the non-integer error case
        if selected in (1, 3):
            out_h += rng.choice(options) * stride[0]
        if selected in (2, 3):
            out_w += rng.choice(options) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - set([ifm.dtype]))
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004915
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a FULLY_CONNECTED operator to the serializer.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).
    The output dtype is always ``accum_dtype``. It was validated (or
    deliberately invalidated for ERROR_IF tests) by the argument generator.
    ``rng`` and ``error_name`` are not read here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004928
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a MATMUL operator to the serializer.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    ERROR_IF handling:
      * WrongOutputType - pick from a per-input-dtype tuple of wrong output
        dtypes. For input dtypes without a dedicated tuple, fall back to any
        usable dtype other than ``accum_dtype``.
      * WrongInputType - INT32 is used as a potentially correct output dtype.
    Otherwise the output dtype is ``accum_dtype``, which is validated in arg_gen.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype previously left incorrect_types unbound
            # (UnboundLocalError). Use every usable dtype except the expected one.
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004976
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for a CONCAT operator to the serializer.

    The output starts as a copy of the first input's shape. Unless the test
    has mismatched ranks or an invalid axis, the extents of the remaining
    inputs along ``axis`` are added to it.

    ERROR_IF handling:
      * ConcatShapeSumMismatch - the axis extent is grown by a random 5-9.
      * WrongOutputType - a random dtype other than the first input's dtype.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    # Summing along `axis` is impossible for these cases, so the first
    # input's shape is kept.
    unsummable = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in unsummable:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        output_dtype = rng.choice(list(candidates - set([first.dtype])))
    else:
        output_dtype = first.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005012
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for a PAD operator to the serializer.

    Each output extent is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF handling:
      * PadOutputShapeMismatch - one random dimension shrinks by 1 or 2.
      * RankMismatch - the shape is replaced via gtu.get_rank_mismatch_shape.
      * WrongOutputType - a random dtype other than the input dtype.
    """
    output_shape = a.shape.copy()
    for dim, extent in enumerate(a.shape):
        output_shape[dim] = padding[dim][0] + padding[dim][1] + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005043
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a DIM operator to the serializer.

    The output is created with an empty shape list and dtype SHAPE.
    For WrongOutputType a random dtype from a list that does not contain
    SHAPE is used instead. ``a`` and ``axis`` are not read here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([], rng.choice(list(set(non_shape_dtypes))))
5064
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for a RESHAPE operator to the serializer.

    The output shape is a copy of the requested ``shape``.

    ERROR_IF handling:
      * TensorSizeInputOutputMismatch - every dimension grows by
        ``rng.integers(1, 10)``.
      * WrongOutputType - a random dtype other than the input dtype.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, extent in enumerate(shape):
            output_shape[idx] = extent + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005089
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for a SLICE operator to the serializer.

    The output shape is a copy of ``size``. ``start`` is not read here.

    ERROR_IF handling:
      * WrongOutputType - a random dtype other than the input dtype.
      * SizeOutputShapeMismatch - every extent is nudged. Extents <= 2 grow by
        1-2, larger ones move by -2..+2 (never 0).
      * InputSizeStartLengthMismatch - the input shape is used instead.
      * RankMismatch - the shape is replaced via gtu.get_rank_mismatch_shape.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([input.dtype])))
    else:
        output_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for pos, extent in enumerate(size):
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[pos] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005123
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for a TILE operator to the serializer.

    Each output extent is ``a.shape[i] * multiples[i]``. ``multiples`` must
    have the same length as the input shape.

    ERROR_IF handling:
      * RankMismatch - the shape is replaced via gtu.get_rank_mismatch_shape.
      * WrongOutputType - a random dtype other than the input dtype.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, extent in enumerate(a.shape):
        output_shape[idx] = extent * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005152
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for a TRANSPOSE operator to the serializer.

    Output dimension i takes ``a.shape[perms[i]]``. For IndexOutsideBounds
    and IndexUsedTwice the permutation is not applied, so the input shape is
    kept.

    ERROR_IF handling:
      * TensorSizeInputOutputMismatch - every dimension grows by
        ``rng.integers(1, 10)``.
      * RankMismatch - the shape is replaced via gtu.get_rank_mismatch_shape.
      * WrongOutputType - a random dtype other than the input dtype.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    bad_perm_errors = (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if error_name not in bad_perm_errors:
        for dst, _unused in enumerate(a.shape):
            output_shape[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst, extent in enumerate(output_shape):
            output_shape[dst] = extent + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005185
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for a GATHER operator to the serializer.

    Shapes (checked unless WrongRank): values is (N, K, C), indices is
    (N, W), output is (N, W, C).

    ERROR_IF handling:
      * WrongOutputType - a random dtype other than the values dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert (len(values.shape), len(indices.shape)) == (3, 2)
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    count = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, count, channels]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([values.dtype])))
    else:
        output_dtype = values.dtype

    return ser.addOutput(output_shape, output_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005211
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for a SCATTER operator to the serializer.

    Shapes (checked unless WrongRank): values_in is (N, K, C), indices is
    (N, W), input is (N, W, C). The output reuses ``values_in.shape``.

    ERROR_IF handling:
      * WrongOutputType - a random dtype other than the values_in dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert (len(values_in.shape), len(indices.shape), len(input.shape)) == (
            3,
            2,
            3,
        )
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice(list(set(candidates) - set([values_in.dtype])))
    else:
        output_dtype = values_in.dtype

    # values_in.shape is passed through as the same object (not copied)
    return ser.addOutput(values_in.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005240
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for a TABLE operator to the serializer.

    The output keeps the input shape. INT16 inputs produce INT32 and any other
    input dtype produces INT8. Unless WrongInputType is requested, the input
    must be INT8 or INT16.

    ERROR_IF handling:
      * WrongOutputType - a random dtype other than the expected output dtype.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice([dt for dt in candidates if dt != out_dtype])

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005260
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for a RESIZE operator to the serializer.

    ``scale`` is read as (y_n, y_d, x_n, x_d). Per axis the output extent is
    ``((in - 1) * n - offset + border) // d + 1``. ``mode`` and
    ``input_dtype`` are not read here.

    ERROR_IF handling:
      * ScaleSmallerEqualZero - scale terms are clamped to >= 1 for the shape
        computation.
      * any error - OH/OW are clamped to >= 1 and, except for MaxDimExceeded,
        to < gtu.MAX_RESIZE_DIMENSION.
      * ResizeOutputShapeMismatch - OH, OW or both move by one denominator.
      * WrongRank / BatchMismatch / ChannelMismatch - batch/channel entries
        of the output dims are altered.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        y_num, y_den, x_num, x_den = (max(v, 1) for v in (y_num, y_den, x_num, x_den))

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # scale/offset/border may have been altered for ERROR_IF tests; keep
        # the output tensor itself valid
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, limit), min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(extent, step):
            # Move by one denominator step so the non-integer error case is
            # not hit; step down when stepping up would reach the size limit.
            if extent + step >= gtu.MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
                return extent
            return extent + step

        selected = rng.choice([1, 2, 3])
        if selected in (1, 3):
            out_h = _shift(out_h, y_den)
        if selected in (2, 3):
            out_w = _shift(out_w, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        channels = input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        batch = input.shape[0] + rng.integers(1, 10)
        channels = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        channels = input.shape[3] + rng.integers(1, 10)
    else:
        channels = input.shape[3]
    output_dims = [batch, out_h, out_w, channels]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005339
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are not read here.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005343
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a TRANSPOSE_CONV2D operator to the serializer.

    ``output_shape`` is supplied by the caller. NOTE(review): for
    ConvOutputShapeMismatch it is modified *in place*, so the caller's list
    changes too. Confirm that callers rely on, or are unaffected by, this.

    ERROR_IF handling:
      * ConvOutputShapeMismatch - H (index 1), W (index 2) or both grow by 1-3.
      * WrongInputType - INT32 is used as a potentially correct output dtype.
      * WrongOutputType - a random usable dtype other than the expected one.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        # 1 -> H only, 2 -> W only, 3 -> both
        selected = rng.choice(options)
        for axis in (1, 2):
            if selected in (axis, 3):
                output_shape[axis] += rng.choice(options)

    # With a bad input type, use some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # are acceptable accumulator types)
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005369
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an FFT2D operator to the serializer.

    Both inputs must share a dtype. Their shapes must match (unless
    FFTInputShapeMismatch) and be rank 3 (unless WrongRank). Both outputs use
    a copy of ifm1's shape and ifm1's dtype.

    ERROR_IF handling:
      * WrongOutputType - a random usable dtype other than FP32.
      * BatchMismatch - dimension 0 grows by 1-9.
      * FFTOutputShapeMismatch - dimension 1 or 2 grows by 1-9.

    Returns a list with the two results of ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        output_shape[bad_dim] += rng.integers(1, 10)

    # Two outputs (presumably real and imaginary parts) share one shape list
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5400
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an RFFT2D operator to the serializer.

    The input must be rank 3 (unless WrongRank). The outputs keep every
    dimension except the last, which becomes ``W // 2 + 1``, and use the
    input dtype.

    ERROR_IF handling:
      * WrongOutputType - a random usable dtype other than FP32.
      * BatchMismatch - dimension 0 grows by 1-9.
      * FFTOutputShapeMismatch - dimension 1 or 2 grows by 1-9.

    Returns a list with the two results of ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        output_shape[bad_dim] += rng.integers(1, 10)

    # Two outputs (presumably real and imaginary parts) share one shape list
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]