blob: 9f65fd4bedc6b8b5e3adfeb7669e7ece7f31f21f [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file.
# The copyright year comes from the local date when the module is imported.
TOSA_AUTOGENERATED_HEADER = "".join(
    f"{banner_line}\n"
    for banner_line in (
        f"// Copyright (c) {datetime.today().year}, ARM Limited",
        "// SPDX-License-Identifier: Apache-2.0",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests",
    )
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further limits named after the 8K level (scale, kernel, stride).
    # NOTE(review): the code that consumes these lies outside this chunk;
    # confirm there exactly which attribute each one bounds.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range (test set indices 0-5)
    TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
    # NOTE(review): presumably a minimum size/count used when setting up dot
    # product compliance tests - confirm against its users.
    TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
69 vals.append(v)
70 return tuple(sorted(vals))
71
72 self.random_float_range = {}
73 for dtype in (DType.FP32, DType.FP16, DType.BF16):
74 self.random_float_range[dtype] = convertFPRange(
75 args.tensor_fp_value_range,
76 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
77 )
78
Eric Kunzee5e26762020-10-13 16:11:07 -070079 def createSerializer(self, opName, testPath):
80 self.testPath = os.path.join(opName, testPath)
81
82 fullPath = os.path.join(self.basePath, self.testPath)
83 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010084 # Embed const data in the flatbuffer
85 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010086 if self.args.lazy_data_gen:
87 # Lazy data generation - so make constants files
88 constMode = ts.ConstMode.INPUTS
89 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 constMode = ts.ConstMode.EMBED_DUMP
91 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
93 def getSerializer(self):
94 return self.ser
95
Jeremy Johnson1271c442023-09-05 11:39:26 +010096 def serialize(self, testName, metaData=None):
97 path = Path(self.basePath) / self.testPath
98
99 # Write out TOSA flatbuffer binary
100 path_fb = path / f"{testName}.tosa"
101 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700102 fd.write(self.ser.serialize())
103
Jeremy Johnson1271c442023-09-05 11:39:26 +0100104 # Get JSON descriptor from serializer
105 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
106
107 if metaData:
108 # Add extra meta data to desc.json
109 desc["meta"] = metaData
110
111 # Validate desc.json before we output it
112 self.descSchemaValidator.validate_config(desc)
113
114 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100115 if "data_gen" in metaData:
116 if self.args.lazy_data_gen:
117 # Output datagen meta data as CPP data
118 path_md = path / f"{testName}_meta_data_gen.cpp"
119 with path_md.open("w") as fd:
120 fd.write(TOSA_AUTOGENERATED_HEADER)
121 fd.write("// Test meta data for data generation setup\n\n")
122 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
123 json.dump(metaData["data_gen"], fd)
124 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100125 if "compliance" in metaData:
126 # Output datagen meta data as CPP data
127 path_md = path / f"{testName}_meta_compliance.cpp"
128 with path_md.open("w") as fd:
129 fd.write(TOSA_AUTOGENERATED_HEADER)
130 fd.write("// Test meta data for compliance validation\n\n")
131 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
132 json.dump(metaData["compliance"], fd)
133 fd.write(')";\n\n')
134
135 # Write desc.json
136 path_desc = path / "desc.json"
137 with path_desc.open("w") as fd:
138 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
Matthew Haddon74567092021-07-16 15:38:20 +0100140 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000141 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100142 seed = self.random_seed + 1
143 self.rng = np.random.default_rng(seed)
144
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 def getDTypeRange(self, dtype, high_inclusive=False):
146 # Returns dtype value range boundaries (low, high)
147 # The high boundary is excluded in the range
148 # unless high_inclusive is True
Jeremy Johnson1271c442023-09-05 11:39:26 +0100149 if dtype in (DType.FP32, DType.FP16, DType.BF16):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100150 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 elif dtype == DType.BOOL:
152 rng = (0, 2)
153 elif dtype == DType.UINT8:
154 rng = (0, 256)
155 elif dtype == DType.UINT16:
156 rng = (0, 65536)
157 elif dtype == DType.INT4:
158 # TOSA specific INT4 weight range from -7 to 7
159 rng = (-7, 8)
160 elif dtype == DType.INT8:
161 rng = (-128, 128)
162 elif dtype == DType.INT16:
163 rng = (-32768, 32768)
164 elif dtype in (DType.INT32, DType.SHAPE):
165 # restricting too large value for SHAPE
166 rng = (-(1 << 31), (1 << 31))
167 elif dtype == DType.INT48:
168 rng = (-(1 << 47), (1 << 47))
169 else:
170 raise Exception("Unknown dtype: {}".format(dtype))
171
172 if not high_inclusive:
173 # Exclusive high: low <= range < high
174 return rng
175 else:
176 # Inclusive range: low <= range <= high
177 return (rng[0], rng[1] - 1)
178
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000179 def getRandTensor(self, shape, dtype, data_range=None):
180 if data_range is None:
181 low, high = self.getDTypeRange(dtype)
182 else:
183 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100184
Eric Kunzee5e26762020-10-13 16:11:07 -0700185 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100188 return np.int64(self.rng.integers(low=low, high=high, size=shape))
189 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
190 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
191
192 if dtype == DType.FP16:
193 return np.float16(f_tensor)
194 else:
195 f32_tensor = np.float32(f_tensor)
196 if dtype == DType.BF16:
197 # Floor the last 16 bits of each f32 value
198 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
199 else:
200 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100202 # All other integer types
203 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
Kevin Cheng989cb052021-04-28 16:29:44 -0700205 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 placeholders = []
207
Kevin Cheng989cb052021-04-28 16:29:44 -0700208 assert len(shape_list) == len(dtype_list)
209
Jeremy Johnson1271c442023-09-05 11:39:26 +0100210 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700211 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100212 if not self.args.lazy_data_gen:
213 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700214 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
216 return placeholders
217
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 consts = []
220
Kevin Cheng989cb052021-04-28 16:29:44 -0700221 assert len(shape_list) == len(dtype_list)
222
Jeremy Johnson1271c442023-09-05 11:39:26 +0100223 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700224 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100225 if not self.args.lazy_data_gen:
226 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700227 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700228
229 return consts
230
231 def makeShape(self, rank):
232 if self.targetted_shape:
233 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800234 return np.int32(
235 self.rng.integers(
236 low=self.args.tensor_shape_range[0],
237 high=self.args.tensor_shape_range[1],
238 size=rank,
239 )
240 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
242 def setTargetShape(self, shape):
243 self.targetted_shape = shape
244
245 def randInt(self, low=0, high=256):
246 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
247
248 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100249 low, high = self.getDTypeRange(dtype)
250
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100251 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100252 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100253 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100254 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100255 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
257 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 elif dtype == DType.BOOL:
259 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700260 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 # Special size
262 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700263
264 return np.int32(self.rng.integers(low, high, size=1))[0]
265
266 def shapeStr(self, shape):
267
268 sStr = []
269 # Convert to strings
270 for i in shape:
271 sStr.append(str(i))
272
Kevin Cheng550ccc52021-03-03 11:21:43 -0800273 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100275 def typeStr(self, dtype):
276 if isinstance(dtype, list) or isinstance(dtype, tuple):
277 assert len(dtype) >= 2
278 strs = [self.typeStr(t) for t in dtype]
279 # Limit types to the first 2 as the 3rd is the accumulator
280 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700281 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100282 if dtype in gtu.DTYPE_ATTRIBUTES:
283 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700284 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100285 raise Exception(
286 "Unknown dtype, cannot convert to string: {}".format(dtype)
287 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100289 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100290 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100291 if dtype in gtu.DTYPE_ATTRIBUTES:
292 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100294 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
Luke Hutton57287132023-02-06 14:54:18 +0000296 def constrictBatchSize(self, shape):
297 # Limit the batch size unless an explicit target shape set
298 if self.args.max_batch_size and not self.args.target_shapes:
299 shape[0] = min(shape[0], self.args.max_batch_size)
300 return shape
301
James Ward30124a82023-02-02 14:56:33 +0000302 def makeDimension(self):
303 return self.randInt(
304 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
305 )
306
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100307 def tensorComplianceMetaData(
308 self, op, inputType, argsDict, outputTensor, errorName
309 ):
310 if (
311 errorName
312 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
313 or not gtu.dtypeIsSupportedByCompliance(inputType)
314 ):
315 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100316 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100317
Jeremy Johnson1271c442023-09-05 11:39:26 +0100318 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100319 compliance_tens = {
320 "mode": None,
321 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
322 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
323 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100324 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
325 mode = gtu.ComplianceMode.DOT_PRODUCT
326 compliance_tens["dot_product_info"] = {
327 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100328 "ks": int(argsDict["ksb"])
329 if "ksb" in argsDict
330 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100331 }
332 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
333 mode = gtu.ComplianceMode.FP_SPECIAL
334 elif "compliance" in op and "ulp" in op["compliance"]:
335 mode = gtu.ComplianceMode.ULP
336 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
337 elif op["op"] == Op.REDUCE_PRODUCT:
338 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnson9a758382023-11-07 16:27:35 +0000339 elif op["op"] in (Op.EXP, Op.POW):
340 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnson1271c442023-09-05 11:39:26 +0100341 else:
342 mode = gtu.ComplianceMode.EXACT
343 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
344
345 return compliance_tens
346
347 # Build Op functions
348 # Create the output tensor (calling OutputShaper as needed)
349 # Do final tweaks to attributes (if necessary for errorIf)
350 # Add Op into graph
351 # Return resulting tensor information or BuildInfo
352
353 class BuildInfo:
354 """Enhanced build information containing result tensor and associated compliance dict."""
355
356 def __init__(self, resultTensor, complianceDict):
357 self.resultTensor = resultTensor
358 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700359
    def build_unary(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a single-input op to the graph.

        Args:
            op: op definition dict; ``op["op"]`` is the TOSA op and
                ``op["operands"]`` is unpacked as (pCount, cCount) -
                presumably placeholder and const operand counts.
            inputs: list containing exactly one input tensor.
            args_dict: generator arguments, forwarded to
                tensorComplianceMetaData.
            validator_fcns: ERROR_IF validator functions passed to
                TosaErrorValidator.evValidateErrorIfs.
            error_name: ERROR_IF case being generated, or None.
            qinfo: [input zero point, output zero point]; used for the NEGATE
                attribute.

        Returns:
            TosaTestGen.BuildInfo with the result tensor and compliance meta
            data (always None for LOG), or None when ERROR_IF validation
            rejects the test.
        """
        assert len(inputs) == 1
        a = inputs[0]
        # May consume self.rng, so it must run before the steps below
        result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        # NOTE(review): presumably guards against being passed a bare Op value
        # instead of the op definition dict indexed below
        assert not isinstance(op, int)

        # Ensure new output type has correct qinfo: for the WrongOutputType
        # test, recompute zero points for the new output dtype unless it is
        # INT8/UINT8
        if error_name == ErrorIf.WrongOutputType:
            if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = [
                    TosaQuantGen.getZeroPoint(self, a.dtype),
                    TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
                ]

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            qinfo=qinfo,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        # Only NEGATE carries an attribute (its zero points)
        attr = None
        if op["op"] == Op.NEGATE:
            attr = ts.TosaSerializerAttribute()
            attr.NegateAttribute(qinfo[0], qinfo[1])

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        if op["op"] in (Op.LOG,):
            # TODO - add compliance support LOG
            compliance = None
        else:
            compliance = self.tensorComplianceMetaData(
                op, a.dtype, args_dict, result_tensor, error_name
            )
        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700416
    def build_binary_broadcast(
        self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
    ):
        """Add a two-input op whose output shape comes from
        OutputShaper.binaryBroadcastOp.

        Args:
            op: op definition dict (``op["op"]``, ``op["operands"]``).
            inputs: list of exactly two input tensors.
            args_dict: generator arguments, forwarded to
                tensorComplianceMetaData.
            validator_fcns: ERROR_IF validator functions (required here,
                unlike several sibling builders).
            error_name: ERROR_IF case being generated, or None.
            qinfo: accepted for a uniform builder signature; unused here.

        Returns:
            TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects
            the test.
        """
        assert len(inputs) == 2
        a, b = inputs
        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700458
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100459 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700460 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000461 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700462 return result_tens
463
    def build_arithmetic_right_shift(
        self, op, a, b, round, validator_fcns=None, error_name=None
    ):
        """Add an ARITHMETIC_RIGHT_SHIFT-style op with its ``round`` attribute.

        ``round`` is passed to ArithmeticRightShiftAttribute.
        NOTE(review): the name shadows the builtin ``round``; it is kept
        because callers may pass it by keyword.

        Returns:
            The result tensor, or None when ERROR_IF validation rejects the
            test.
        """
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens
501
    def build_mul(
        self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Add a MUL-style op with its ``shift`` attribute.

        Args:
            op: op definition dict (``op["op"]``, ``op["operands"]``).
            inputs: list of exactly two input tensors.
            args_dict: generator arguments; ``args_dict["shift"]`` feeds the
                MulAttribute and the dict is forwarded to
                tensorComplianceMetaData.
            validator_fcns: ERROR_IF validator functions.
            error_name: ERROR_IF case being generated, or None.
            qinfo: accepted for a uniform builder signature; unused here.

        Returns:
            TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects
            the test.
        """
        assert len(inputs) == 2
        a, b = inputs
        shift = args_dict["shift"]

        result_tensor = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Special for multiply: Force the result to INT32 for INT types
        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
            result_tensor.setDtype(DType.INT32)

        # WrongOutputType test: overwrite the output dtype with a random
        # INT8/INT16/INT48 (consumes self.rng, so statement order matters)
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
            outputDType = self.rng.choice(all_dtypes)
            result_tensor.setDtype(outputDType)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700557
    def build_table(self, op, a, table, validator_fcns=None, error_name=None):
        """Add a TABLE-style op whose lookup ``table`` is stored in a
        TableAttribute.

        Returns:
            The result tensor, or None when ERROR_IF validation rejects the
            test.
        """
        result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

        # Built before validation; it is only used if validation passes
        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(table)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        return result_tens
591
    def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
        """Add a SELECT-style op taking ``cond``, ``a`` and ``b`` (in that
        input order).

        Returns:
            The result tensor, or None when ERROR_IF validation rejects the
            test.
        """
        result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [cond.name, a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=cond,
            input2=a,
            input3=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )
        return result_tens
628
    def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
        """Add a two-input comparison op shaped by
        OutputShaper.binaryComparisonOp.

        Returns:
            The result tensor, or None when ERROR_IF validation rejects the
            test.
        """
        result_tens = OutputShaper.binaryComparisonOp(
            self.ser, self.rng, a, b, error_name
        )

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name, b.name]
        output_list = [result_tens.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            input1=a,
            input2=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            result_tensors=[result_tens],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        self.ser.addOperator(
            op["op"],
            input_list,
            output_list,
        )
        return result_tens
667
    def build_argmax(
        self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
    ):
        """Add an ARGMAX-style op reducing along ``args_dict["axis"]``.

        Args:
            op: op definition dict (``op["op"]``, ``op["operands"]``).
            inputs: list containing exactly one input tensor.
            args_dict: generator arguments; ``"axis"`` feeds the AxisAttribute
                and the dict is forwarded to tensorComplianceMetaData.
            validator_fcns: ERROR_IF validator functions.
            error_name: ERROR_IF case being generated, or None.
            qinfo: accepted for a uniform builder signature; unused here.

        Returns:
            TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects
            the test.
        """
        assert len(inputs) == 1
        a = inputs[0]
        axis = args_dict["axis"]
        result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        # Invalidate Input/Output list for error if checks.
        input_list = [a.name]
        output_list = [result_tensor.name]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )

        # A falsy result means this test should not be generated
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            axis=axis,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tensor.shape,
            output_dtype=result_tensor.dtype,
            result_tensors=[result_tensor],
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)

        compliance = self.tensorComplianceMetaData(
            op, inputs[0].dtype, args_dict, result_tensor, error_name
        )
        return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700711
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test; ``op["op"]`` selects the operator.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed for the operand count.
        inputs: List holding exactly one input tensor.
        args_dict: Test arguments; reads ``stride``, ``pad``, ``kernel`` and
            optionally ``acc_type`` (DType.UNKNOWN when absent).
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to PoolAttribute, or None (treated
            as ``[0, 0]``). Replaced for ErrorIf.WrongInputType on
            non-INT8/UINT8 inputs.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    # Named input_tensor rather than ``input`` to avoid shadowing the builtin.
    input_tensor = inputs[0]
    # max_pool has no accum_dtype, so the key may be missing
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, input_tensor, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input_tensor.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100785
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed for the operand count.
        inputs: Exactly three tensors: input feature map, weight, bias.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (must have 4 entries) and ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to ConvAttribute, or None (treated
            as ``[0, 0]`` rather than failing on ``qinfo[0]``). Replaced for
            ErrorIf.WrongInputType on non-INT8/UINT8 inputs.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # ``weight`` instead of ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700867
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed for the operand count.
        inputs: Exactly three tensors: input feature map, weight, bias.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (must have 6 entries) and ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to ConvAttribute, or None (treated
            as ``[0, 0]`` rather than failing on ``qinfo[0]``). Replaced for
            ErrorIf.WrongInputType on non-INT8/UINT8 inputs.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test.
    """
    assert len(inputs) == 3
    # ``weight`` instead of ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
944
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed for the operand count.
        ifm: Input feature map tensor.
        filter: Weight tensor. The name shadows the builtin ``filter`` but
            is kept so keyword callers keep working.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride values.
        out_pad: Output padding (must have 4 entries).
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to TransposeConvAttribute, or None
            (treated as ``[0, 0]`` rather than failing on ``qinfo[0]``).
            Replaced for ErrorIf.WrongInputType on non-INT8/UINT8 inputs.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1012
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and
            ``op["operands"]`` is summed for the operand count.
        inputs: Exactly three tensors: input feature map, weight, bias.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to ConvAttribute, or None (treated
            as ``[0, 0]`` rather than failing on ``qinfo[0]``). Replaced for
            ErrorIf.WrongInputType on non-INT8/UINT8 inputs.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test.
    """
    assert len(inputs) == 3
    # ``weight`` instead of ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1088
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed for the operand count.
        inputs: Exactly three tensors: input, weight, bias.
        args_dict: Test arguments; reads ``acc_type``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to FullyConnectedAttribute, or None
            (treated as ``[0, 0]`` rather than failing on ``qinfo[0]``).

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # ``weight`` instead of ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001144
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test from two input tensors.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed for the operand count.
        inputs: Exactly two tensors ``[a, b]``.
        args_dict: Test arguments; reads ``acc_type``.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points forwarded to MatMulAttribute, or None
            (treated as ``[0, 0]`` rather than failing on ``qinfo[0]``).

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default to zero zero-points so a missing qinfo does not raise
    # TypeError when the attribute is built below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001194
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reduction operator test; ``op["op"]`` selects the reduction.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed for the operand count.
        inputs: List holding exactly one input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` is the axis passed
            to the output shaper and the AxisAttribute.
        validator_fcns: ERROR_IF validator functions (may be None).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance
        metadata (None for REDUCE_PRODUCT), or None when the ERROR_IF
        validators reject the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if op["op"] == Op.REDUCE_PRODUCT:
        # TODO: Add compliance support!
        compliance = None
    else:
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001243
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator test with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the larger one
    becomes ``max_val`` and the smaller one ``min_val``. For
    ErrorIf.MaxSmallerMin the pair is redrawn until the values differ and
    is then assigned the other way round, so that max < min.

    For BF16/FP16/FP32 inputs the bounds go in the last two value slots of
    ClampAttribute (the first two are 0); other dtypes use the first two
    slots. FP16 bounds are converted to float32 first.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_bounds():
        # Two independent random values of the input dtype.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_bounds()
    inverted = error_name == ErrorIf.MaxSmallerMin
    if inverted:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_bounds()

    smaller, larger = min(bounds), max(bounds)
    if inverted:
        max_val, min_val = smaller, larger
    else:
        max_val, min_val = larger, smaller

    p_count, c_count = op["operands"]
    # ERROR_IF tests may swap in deliberately invalid name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
1299
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator test on tensor ``a``.

    A random FP32 value is passed to LeakyReluAttribute (presumably the
    negative-slope coefficient; confirm against the serializer). No
    ERROR_IF validation is done here, so ``validator_fcns`` is unused.

    Returns:
        The result tensor.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    # Drawn after the output shaper runs, keeping the original RNG order.
    attr_value = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [output.name], leaky_attr)
    return output
1308
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator test on tensor ``a``.

    NOTE(review): PRELU needs an additional type/input that this builder
    does not provide yet. No ERROR_IF validation is done, so
    ``validator_fcns`` is unused.

    Returns:
        The result tensor.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names = [a.name]
    out_names = [output.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return output
1315
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator test on tensor ``a``.

    The output tensor comes from OutputShaper.unaryOp. For ERROR_IF tests
    the input/output name lists may be replaced with invalid ones before
    the validators run.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test (no operator is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may swap in deliberately invalid name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1346
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator test on tensor ``a``.

    The output tensor comes from OutputShaper.unaryOp. For ERROR_IF tests
    the input/output name lists may be replaced with invalid ones before
    the validators run.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        test (no operator is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may swap in deliberately invalid name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1377
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator test on tensor ``a``.

    The output tensor comes from ``OutputShaper.unaryOp``. For ERROR_IF
    tests the name lists may be altered by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` before validation.

    Returns:
        The result tensor if the ERROR_IF validators accept the test,
        otherwise None.
    """
    erf_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_out.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=erf_out.shape,
        input_dtype=a.dtype,
        output_dtype=erf_out.dtype,
        result_tensors=[erf_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    ):
        self.ser.addOperator(op["op"], in_names, out_names)
        return erf_out
    return None
1408
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT operator test over ``inputs`` along ``args_dict["axis"]``.

    Args:
        op: operator description; ``op["op"]`` is serialized and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        inputs: tensors to concatenate; ``inputs[0]`` supplies the
            shape/dtype handed to the ERROR_IF validators.
        args_dict: must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with no compliance data, or None when the
        ERROR_IF validators reject the test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance is the idiomatic type check (instead of type(...) == int).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001456
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator test on a single input tensor.

    ``args_dict`` supplies ``"pad"`` (flattened into the attribute),
    ``"pad_const_int"`` and ``"pad_const_fp"``.

    Returns:
        ``TosaTestGen.BuildInfo`` carrying the result of
        ``self.tensorComplianceMetaData``, or None if the ERROR_IF
        validators reject the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_values = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, src, pad_values, error_name
    )

    # PadAttribute is handed self.ser.builder, so it is kept ahead of the
    # validation step rather than deferred until the operator is added.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, pad_values.flatten(), const_int, const_fp
    )

    placeholders, consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    validator_kwargs = {
        "op": op,
        "input_shape": src.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": src.dtype,
        "output_dtype": out_tensor.dtype,
        "pad": pad_values,
        "qinfo": qinfo,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": placeholders + consts,
        "input1": src,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001514
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator test on one input, parameterized by ``args_dict["axis"]``.

    ``qinfo`` is accepted for a uniform builder signature but not used.

    Returns:
        ``TosaTestGen.BuildInfo`` without compliance data, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    axis = args_dict["axis"]
    dim_out = OutputShaper.dimOp(self.ser, self.rng, tensor_in, axis, error_name)

    count_p, count_c = op["operands"]
    # Invalidate Input/Output list for error if checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [dim_out.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensor_in.shape,
        input_dtype=tensor_in.dtype,
        output_shape=dim_out.shape,
        output_dtype=dim_out.dtype,
        result_tensors=[dim_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=count_p + count_c,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return TosaTestGen.BuildInfo(dim_out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001560
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE operator test reshaping ``a`` to ``newShape``.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    reshaped = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    n_place, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=reshaped.shape,
        input_dtype=a.dtype,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_place + n_const,
    ):
        shape_attr = ts.TosaSerializerAttribute()
        shape_attr.ReshapeAttribute(newShape)
        self.ser.addOperator(op["op"], in_names, out_names, shape_attr)
        return reshaped
    return None
1596
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator test on one input along ``args_dict["axis"]``.

    The output tensor is shaped by ``OutputShaper.unaryOp``. ``qinfo`` is
    accepted for a uniform builder signature but not used.

    Returns:
        ``TosaTestGen.BuildInfo`` without compliance data, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    axis = args_dict["axis"]
    reversed_out = OutputShaper.unaryOp(self.ser, self.rng, source, error_name)

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [reversed_out.name]
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": source.shape,
        "output_shape": reversed_out.shape,
        "input_dtype": source.dtype,
        "output_dtype": reversed_out.dtype,
        "result_tensors": [reversed_out],
        "input_list": in_list,
        "output_list": out_list,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_list, out_list, axis_attr)
    return TosaTestGen.BuildInfo(reversed_out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001636
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator test on ``a`` with permutation ``perms``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the test.
    """
    transposed = OutputShaper.transposeOp(
        self.ser, self.rng, a, perms, error_name
    )

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    num_placeholders, num_consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [transposed.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=transposed.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=transposed.dtype,
        result_tensors=[transposed],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
        input1=a,
    ):
        self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
        return transposed
    return None
1672
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator test on ``a`` parameterized by ``start`` and ``size``.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholders, consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    validator_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=sliced.shape,
        input_dtype=a.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
        input1=a,
    )
    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    )
    if not ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return sliced
1711
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator test on ``a`` using ``multiples``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject it.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    count_p, count_c = op["operands"]
    # Invalidate Input/Output list for error if checks.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=names_in,
        output_list=names_out,
        num_operands=count_p + count_c,
        input1=a,
    ):
        tile_attr = ts.TosaSerializerAttribute()
        tile_attr.TileAttribute(multiples)
        self.ser.addOperator(op["op"], names_in, names_out, tile_attr)
        return tiled
    return None
1746
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator test reading from ``values``.

    A constant INT32 indices tensor of shape ``[values.shape[0], W]`` is
    created, where ``W`` comes from ``self.randInt`` over
    ``self.args.tensor_shape_range``. Every index is drawn from
    ``[0, values.shape[1])``, so it never exceeds the K dimension of
    ``values``.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    num_k = values.shape[1]
    width = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=num_k, size=[values.shape[0], width])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    n_place, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [gathered.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_place + n_const,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1793
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator test writing ``input`` into ``values_in``.

    A constant INT32 indices tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is created with every entry
    drawn from ``[0, values_in.shape[1])``. Operator inputs are, in order,
    ``values_in``, the indices and ``input``. (``input`` shadows the builtin
    but is kept as the parameter name for callers.)

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    k_extent = values_in.shape[1]
    w_extent = input.shape[1]
    index_data = np.int32(
        self.rng.integers(
            low=0, high=k_extent, size=[values_in.shape[0], w_extent]
        )
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    placeholders, consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_in.name, indices.name, input.name], [scattered.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        self.ser.addOperator(op["op"], in_names, out_names)
        return scattered
    return None
1838
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator test on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded to both
    ``OutputShaper.resizeOp`` and the serialized ``ResizeAttribute``.
    ``input_dtype``/``output_dtype`` are the caller-supplied dtypes given to
    the output shaper and the ERROR_IF validators. (``input`` shadows the
    builtin but is kept as the parameter name for callers.)

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    resize_checks = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": resized.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [resized],
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **resize_checks
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1900
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator test with two inputs and two outputs.

    Each output tensor is shaped from its input by ``OutputShaper.unaryOp``.
    No ERROR_IF validation is done here; ``validator_fcns`` is unused.

    Returns:
        The first result tensor only. The second one is still registered as
        an output of the operator.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator description dict (as in the sibling builders,
    # which read op["operands"]); the serializer needs the opcode under
    # op["op"], not the whole dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1908
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Pass ``val`` to ``self.ser.addOutputTensor`` and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for a uniform
    builder signature but are not used.
    """
    const_tensor = val
    self.ser.addOutputTensor(const_tensor)
    return const_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001912
1913 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator test converting ``val`` to ``out_dtype``.

    The output tensor is produced by ``OutputShaper.typeConversionOp``.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the test.
    """
    cast_out = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    n_place, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [cast_out.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=cast_out.shape,
        input_dtype=val.dtype,
        output_dtype=cast_out.dtype,
        result_tensors=[cast_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_place + n_const,
    ):
        self.ser.addOperator(op["op"], in_names, out_names)
        return cast_out
    return None
1946
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001947 def build_rescale(
1948 self,
1949 op,
1950 val,
1951 out_dtype,
1952 scale32,
1953 double_round,
1954 per_channel,
1955 validator_fcns,
1956 error_name,
1957 ):
1958 result_tens = OutputShaper.typeConversionOp(
1959 self.ser, self.rng, val, out_dtype, error_name
1960 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001961
1962 if per_channel:
1963 nc = val.shape[-1]
1964 else:
1965 nc = 1
1966
1967 in_type_width = self.typeWidth(val.dtype)
1968 out_type_width = self.typeWidth(out_dtype)
1969
Kevin Cheng3a478572021-01-22 17:21:02 -08001970 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001971 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001972 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001973 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001974 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001975 in_type_width += 1
1976 elif error_name in [
1977 ErrorIf.InputZeroPointNotZero,
1978 ErrorIf.U16InputZeroPointNotValid,
1979 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001980 input_zp = self.randInt(-128, 128)
1981 if input_zp == 0:
1982 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001983 in_type_width += 1
1984 elif val.dtype == DType.UINT16:
1985 # Must come after ErrorIf.U16InputZeroPointNotValid check
1986 input_zp = self.rng.choice([0, 32768])
1987 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001988 else:
1989 input_zp = 0
1990
Kevin Cheng3a478572021-01-22 17:21:02 -08001991 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001992 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001993 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001994 elif out_dtype == DType.UINT8:
1995 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001996 out_type_width += 1
1997 elif error_name in [
1998 ErrorIf.OutputZeroPointNotZero,
1999 ErrorIf.U16OutputZeroPointNotValid,
2000 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01002001 output_zp = self.randInt(-128, 128)
2002 if output_zp == 0:
2003 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002004 out_type_width += 1
2005 elif out_dtype == DType.UINT16:
2006 # Must come after ErrorIf.U16OutputZeroPointNotValid check
2007 output_zp = self.rng.choice([0, 32768])
2008 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002009 else:
2010 output_zp = 0
2011
2012 # Calculate scale based on:
2013 # scale = a *(2^output_width)/(2^input_width))
2014
2015 a = np.float32(self.rng.random(size=[nc]))
2016 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2017
2018 if scale32:
2019 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002020 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002021 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2022 else:
2023 # Cap the scaling at 2^15 - 1 for scale16
2024 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2025
Kevin Cheng550ccc52021-03-03 11:21:43 -08002026 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002027
2028 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2029 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002030 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2031 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002032
2033 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002034 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2035 scale_arr[i], scale32
2036 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002037 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2038 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002039
Kevin Cheng550ccc52021-03-03 11:21:43 -08002040 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002041 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002042 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002043 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002044 assert val.placeholderFilename
2045 values = np.load(
2046 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2047 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002048 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2049 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2050 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2051 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002052 if not np.all(np.array_equal(values, val_adj)):
2053 # Values changed so overwrite file with new values
2054 np.save(
2055 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2056 val_adj,
2057 False,
2058 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002059
Matthew Haddonc2025212021-10-08 21:21:05 +01002060 # Invalidate Input/Output list for error if checks.
2061 input_list = [val.name]
2062 output_list = [result_tens.name]
2063 pCount, cCount = op["operands"]
2064 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002065 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2066 self, error_name, input_list, output_list
2067 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002068
2069 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002070 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002071 self.ser,
2072 validator_fcns,
2073 error_name,
2074 op=op,
2075 input_dtype=val.dtype,
2076 output_dtype=out_dtype,
2077 input_shape=val.shape,
2078 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002079 scale32=scale32,
2080 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002081 input_list=input_list,
2082 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002083 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002084 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002085 ):
2086 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002087
Eric Kunzee5e26762020-10-13 16:11:07 -07002088 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002089 attr.RescaleAttribute(
2090 input_zp,
2091 output_zp,
2092 multiplier_arr,
2093 shift_arr,
2094 scale32,
2095 double_round,
2096 per_channel,
2097 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002098
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002099 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002100 return result_tens
2101
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    Args:
        op: operator description dict, forwarded to get_wrong_output_type.
        cond: the single value stored in the constant.
        error_name: ErrorIf case being generated, or None.

    Returns:
        The tensor returned by self.ser.addConst.
    """
    # Valid condition: a rank-0 (single element) BOOL
    tens_dtype = DType.BOOL
    tens_shape = []

    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        # Negative test: replace BOOL with some other type
        tens_dtype = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)

    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        # Negative test: a shape holding more than one element
        tens_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(tens_shape, tens_dtype, [cond])
2118
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN and ELSE blocks each output a constant.

    Only the shape of then_tens is used; else_tens is ignored. Both blocks
    produce a random INT32 constant of that shape. For the output-list
    mismatch ERROR_IF cases, one block instead produces a constant of a
    different shape.

    Args:
        op: operator description dict (op["op"] is the op code added).
        then_tens: tensor whose shape gives the output shape.
        else_tens: unused.
        cond: value stored in the constant condition tensor.
        validator_fcns: ERROR_IF validator functions to run on the graph.
        error_name: ErrorIf case to construct, or None for a valid test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Change every dimension; small dims (<= 3) only grow so they stay positive
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] = dim + self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor matches the expected (correct) output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block body is a single constant node
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, mismatch_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
2187
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks combine a and b with a binary op.

    FP32/BF16/FP16/INT32 inputs use ADD in the THEN block and SUB in the
    ELSE block. INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT respectively.

    Args:
        op: operator description dict (op["op"] is the op code added).
        a, b: tensors passed to the COND_IF and used as inputs of both blocks.
        cond: value stored in the constant condition tensor.
        validator_fcns: ERROR_IF validator functions to run on the graph.
        error_name: ErrorIf case to construct, or None for a valid test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Change every dimension of a's shape. Dimensions of 3 or less are only
        # grown (the same rule build_cond_if_const uses), so the incorrect shape
        # never contains a zero or negative dimension.
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in a.shape
        ]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable is block_op, not op. Reusing op would replace the
    # operator dict with an Op code before it is passed to the ERROR_IF
    # validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2270
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test that adds a into an accumulator each iteration.

    The loop carries (iter, a, acc):
    - The COND block computes iter > 0.
    - The BODY block computes acc + a and iter - 1, and passes a through.
    - acc starts at zero.

    Args:
        op: operator description dict (op["op"] is the op code added).
        a: tensor added into the accumulator each iteration.
        iter_val: initial value of the rank-0 INT32 loop counter.
        validator_fcns: ERROR_IF validator functions to run on the graph.
        error_name: ErrorIf case to construct, or None for a valid test.

    Returns:
        The loop's accumulator output tensor, or None if ERROR_IF
        validation fails.
    """

    def perturbed_shape(shape):
        # Return a new list with every dimension of shape changed. Dimensions
        # of 3 or less are only grown (the rule build_cond_if_const uses), so
        # the result never contains a zero or negative dimension.
        return [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in shape
        ]

    # Loop counter (named iter_tens so the iter() builtin is not shadowed)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, starting at zero
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out = self.ser.addIntermediate(perturbed_shape(acc.shape), acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        incorrect_iter.shape = perturbed_shape(iter_tens.shape)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 counter has no dimension to change, so give it an
            # extra (positive) dimension instead
            incorrect_iter.shape.append(self.rng.choice([2, 3]))

        incorrect_acc = deepcopy(acc)
        incorrect_acc.shape = perturbed_shape(acc.shape)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2391
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator test from the two input tensors val1 and val2.

    Args:
        op: operator description dict (uses op["op"] and op["operands"]).
        val1, val2: the two input tensors; val1 supplies the input
            shape/dtype handed to the ERROR_IF validators.
        inverse: value for the inverse flag of the FFT attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case to construct, or None for a valid test.

    Returns:
        The output tensors created by OutputShaper.fft2dOp, or None if
        ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    p_count, c_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # ERROR_IF tests may have their input/output name lists altered here
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    local_bound = False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2442
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator test from the single input tensor val.

    Args:
        op: operator description dict (uses op["op"] and op["operands"]).
        val: the input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf case to construct, or None for a valid test.

    Returns:
        The output tensors created by OutputShaper.rfft2dOp, or None if
        ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    p_count, c_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    # ERROR_IF tests may have their input/output name lists altered here
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val.name],
        [res.name for res in results],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=p_count + c_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    local_bound = False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)
    return results
2488
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shapes, ranks and dtypes to generate tests for.

    Requested filters are intersected with what the operator allows:
    op["rank"] is an inclusive (min, max) pair and op["types"] lists the
    supported dtypes.

    Args:
        op: operator description dict.
        shapeFilter: requested shapes; a falsy value becomes [None] (no
            shape restriction).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An operator dtype that is a
            list matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator whose "param_reqs" override the
            filters for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys.
        Returns None for a negative test without a validator, or for an
        unrecognised testType.
    """
    shapes = shapeFilter if shapeFilter else [None]

    rmin, rmax = op["rank"]
    op_ranks = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Drop requested ranks the operator does not support
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapes[0] is None:
        # Nothing requested: use the operator's ranks or ranks 1-4, whichever
        # range is smaller, to keep test sizes reasonably small
        default_ranks = range(1, 5)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}

    if testType != "negative" or validator is None:
        return None

    # The validator may pin the rank, dtype and/or shape for its error case
    reqs = validator(check=False, op=op)["param_reqs"]
    return {
        # Without a required shape keep only two shapes, so the number of
        # negative tests stays small
        "shapeFilter": shapes[:2] if reqs["shape"] is None else reqs["shape"],
        "rankFilter": ranks if reqs["rank"] is None else reqs["rank"],
        "dtypeFilter": dtypes if reqs["dtype"] is None else reqs["dtype"],
    }
2569
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to build for one operator.

    Args:
        opName: key into self.TOSA_OP_LIST.
        shapeFilter: shapes to restrict tests to. None (the default) means
            no shape restriction and is treated as [None].
        rankFilter: ranks to restrict tests to, or None.
        dtypeFilter: dtypes to restrict tests to, or None.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples; empty when there is nothing to generate.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
    """
    # The default used to be the mutable list [None]; normalise None to it
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Fresh generator from the fixed seed for every operator, so an op's
    # list does not depend on which ops were generated before it
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # A single pass without any ERROR_IF validator
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator; keep any tests already
            # built for earlier validators
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (argStr, args) tuples: the
                    # string form for the test name and the build arguments
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2676
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002677 def serializeTest(
2678 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2679 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002680 try:
2681 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002682 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002684
Jeremy Johnson0c716862023-04-13 17:18:19 +01002685 if self.args.verbose:
2686 print(f"Creating {testStr}")
2687
Eric Kunzee5e26762020-10-13 16:11:07 -07002688 # Create a serializer
2689 self.createSerializer(opName, testStr)
2690
Jeremy Johnson1271c442023-09-05 11:39:26 +01002691 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002692 if "error_if_validators" in op:
2693 error_if_validators = op["error_if_validators"]
2694 else:
2695 error_if_validators = None
2696
Kevin Cheng550ccc52021-03-03 11:21:43 -08002697 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002698 num_operands = pCount + cCount
2699
2700 if isinstance(dtype_or_dtypeList, list):
2701 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002702 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002703 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002704 else:
2705 dtypeList = [dtype_or_dtypeList] * (num_operands)
2706
Kevin Cheng93a16282021-08-31 16:14:03 -07002707 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002708 assert (
2709 len(shapeList) == num_operands
2710 ), "shapeList length {} must match number of operands {}".format(
2711 len(shapeList), num_operands
2712 )
2713 assert (
2714 len(dtypeList) == num_operands
2715 ), "dtypeList length {} must match number of operands {}".format(
2716 len(dtypeList), num_operands
2717 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002718
2719 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002720 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002721 except KeyError:
2722 qgen = None
2723
2724 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002725
Matthew Haddon1c00b712021-10-01 15:51:03 +01002726 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002727 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002728 else:
2729 qinfo = None
2730
Jeremy Johnson1271c442023-09-05 11:39:26 +01002731 # Extra meta data for the desc.json
2732 tensMeta = {}
2733
2734 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002735 if isinstance(testArgs, dict):
2736 # New interface with args info in dictionary
2737 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002738 assert "dg_type" in argsDict
2739 tvgInfo = tvgen_fcn(
2740 self, opName, dtypeList, shapeList, argsDict, error_name
2741 )
2742 if tvgInfo.dataGenDict:
2743 tensMeta["data_gen"] = tvgInfo.dataGenDict
2744 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002745
2746 result = build_fcn(
2747 self,
2748 op,
2749 tens,
2750 argsDict,
2751 validator_fcns=error_if_validators,
2752 error_name=error_name,
2753 qinfo=qinfo,
2754 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002755 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002756 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002757 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002758
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002759 try:
2760 if error_if_validators is None:
2761 if qinfo is not None:
2762 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2763 else:
2764 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002765 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002766 if qinfo is not None:
2767 result = build_fcn(
2768 self,
2769 op,
2770 *tens,
2771 *testArgs,
2772 validator_fcns=error_if_validators,
2773 error_name=error_name,
2774 qinfo=qinfo,
2775 )
2776 else:
2777 result = build_fcn(
2778 self,
2779 op,
2780 *tens,
2781 *testArgs,
2782 validator_fcns=error_if_validators,
2783 error_name=error_name,
2784 )
2785 except TypeError as e:
2786 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2787 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002788
Jeremy Johnson1271c442023-09-05 11:39:26 +01002789 if result:
Les Bell729b0352021-11-24 10:28:21 +00002790 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002791 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2792 # Add the compliance meta data
2793 # NOTE: This currently expects only one result output
2794 tensMeta["compliance"] = {
2795 "version": "0.1",
2796 "tensors": {result.resultTensor.name: result.complianceDict},
2797 }
2798 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002799 else:
2800 # The test is not valid
2801 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002802
Eric Kunzee5e26762020-10-13 16:11:07 -07002803 def createDynamicOpLists(self):
2804
Jeremy Johnson00423432022-09-12 17:27:37 +01002805 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2806 # Already created these lists (can occur when class is initialized more than once)
2807 return
2808
Eric Kunzee5e26762020-10-13 16:11:07 -07002809 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002810 if not self.args.level8k:
2811 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2812 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2813 else:
2814 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2815 KERNELS_2D = [[1, bigK], [bigK, 2]]
2816 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002817
Kevin Cheng1533b852021-09-01 12:51:58 -07002818 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002819 testName = "conv2d_{}x{}".format(k[0], k[1])
2820 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2821 self.TOSA_OP_LIST[testName]["filter"] = k
2822 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002823
Kevin Cheng550ccc52021-03-03 11:21:43 -08002824 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2825 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2826 "depthwise_conv2d_TEMPLATE"
2827 ].copy()
2828 self.TOSA_OP_LIST[testName]["filter"] = k
2829 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002830
Kevin Cheng550ccc52021-03-03 11:21:43 -08002831 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2832 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2833 "transpose_conv2d_TEMPLATE"
2834 ].copy()
2835 self.TOSA_OP_LIST[testName]["filter"] = k
2836 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002837
Kevin Cheng1533b852021-09-01 12:51:58 -07002838 for k in KERNELS_3D:
2839 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2840 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2841 self.TOSA_OP_LIST[testName]["filter"] = k
2842 self.TOSA_OP_LIST[testName]["template"] = False
2843
Eric Kunzee5e26762020-10-13 16:11:07 -07002844 # Delete any templates after having created any dynamic ops
2845 # This is a two-pass operation because it's bad practice to delete
2846 # keys from dictionaries while iterating
2847 keyList = []
2848 for k in self.TOSA_OP_LIST:
2849 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002850 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002851 keyList.append(k)
2852 continue
2853 except KeyError:
2854 pass
2855
2856 for k in keyList:
2857 del self.TOSA_OP_LIST[k]
2858
2859 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002860 """Fill in default fields for ops if they aren't already specified.
2861 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002862 for op in self.TOSA_OP_LIST:
2863
2864 # Required fields
2865 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002866 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002867 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002868 raise Exception(
2869 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2870 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002871
2872 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002873 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002874 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002875 raise Exception(
2876 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2877 op
2878 )
2879 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002880
2881 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002882 _ = self.TOSA_OP_LIST[op]["types"]
2883 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002884 raise Exception(
2885 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2886 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002887
2888 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002889 _ = self.TOSA_OP_LIST[op]["op"]
2890 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002891 raise Exception(
2892 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2893 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002894
2895 # Put in default rank range, if missing
2896 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002897 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002898 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002899 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002900
2901 # Tensor operator list
2902 # 'op': op name
2903 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002904 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2905 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002906 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2907 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002908 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002909
Kevin Cheng550ccc52021-03-03 11:21:43 -08002910 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002911 TYPE_INT_FP = [
2912 DType.INT8,
2913 DType.INT16,
2914 DType.INT32,
2915 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002916 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002917 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002918 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002919
Kevin Cheng550ccc52021-03-03 11:21:43 -08002920 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002921 TYPE_FI32 = [
2922 DType.FP32,
2923 DType.FP16,
2924 DType.BF16,
2925 DType.INT32,
2926 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002927 TYPE_FIB = [
2928 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002929 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002930 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002931 DType.INT8,
2932 DType.INT16,
2933 DType.INT32,
2934 DType.BOOL,
2935 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002936 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002937
James Ward24dbc422022-10-19 12:20:31 +01002938 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002939
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002940 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002941 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002942 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002943 [DType.INT8, DType.INT8, DType.INT32],
2944 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002945 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002946 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002947 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002948 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002949 ]
2950
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002951 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002952
2953 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002955 "argmax": {
2956 "op": Op.ARGMAX,
2957 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002958 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002959 "build_fcn": (
2960 build_argmax,
2961 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002962 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002963 TosaArgGen.agAxis,
2964 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002965 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002966 "error_if_validators": (
2967 TosaErrorValidator.evAxisSmallerZero,
2968 TosaErrorValidator.evAxisLargerRank,
2969 TosaErrorValidator.evArgmaxOutputRankMismatch,
2970 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2971 TosaErrorValidator.evWrongRank,
2972 TosaErrorValidator.evWrongInputType,
2973 TosaErrorValidator.evWrongOutputType,
2974 TosaErrorValidator.evWrongInputList,
2975 TosaErrorValidator.evWrongOutputList,
2976 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002977 "data_gen": {
2978 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2979 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002980 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002981 "avg_pool2d": {
2982 "op": Op.AVG_POOL2D,
2983 "operands": (1, 0),
2984 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 "build_fcn": (
2986 build_pool2d,
2987 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002988 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002989 TosaArgGen.agPooling,
2990 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 "qgen": TosaQuantGen.qgUnary,
2992 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002993 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002994 "error_if_validators": (
2995 TosaErrorValidator.evKernelSmallerOne,
2996 TosaErrorValidator.evStrideSmallerOne,
2997 TosaErrorValidator.evPadSmallerZero,
2998 TosaErrorValidator.evWrongRank,
2999 TosaErrorValidator.evWrongInputType,
3000 TosaErrorValidator.evWrongOutputType,
3001 TosaErrorValidator.evWrongInputList,
3002 TosaErrorValidator.evWrongOutputList,
3003 TosaErrorValidator.evInputZeroPointNotZero,
3004 TosaErrorValidator.evOutputZeroPointNotZero,
3005 TosaErrorValidator.evPadLargerEqualKernel,
3006 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003007 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003008 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003009 "data_gen": {
3010 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003013 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003014 "conv2d_TEMPLATE": {
3015 "op": Op.CONV2D,
3016 "operands": (1, 2),
3017 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003018 "build_fcn": (
3019 build_conv2d,
3020 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003021 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003022 TosaArgGen.agConv,
3023 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003024 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003025 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003026 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3027 "error_if_validators": (
3028 TosaErrorValidator.evWrongInputType,
3029 TosaErrorValidator.evWrongOutputType,
3030 TosaErrorValidator.evWrongInputList,
3031 TosaErrorValidator.evWrongOutputList,
3032 TosaErrorValidator.evInputZeroPointNotZero,
3033 TosaErrorValidator.evWeightZeroPointNotZero,
3034 TosaErrorValidator.evPadSmallerZero,
3035 TosaErrorValidator.evStrideSmallerOne,
3036 TosaErrorValidator.evDilationSmallerOne,
3037 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003038 TosaErrorValidator.evConvOutputShapeMismatch,
3039 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003040 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003041 "data_gen": {
3042 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3043 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003044 "template": True,
3045 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003046 # Templated operator. Filled in by createDynamicOpLists
3047 "conv3d_TEMPLATE": {
3048 "op": Op.CONV3D,
3049 "operands": (1, 2),
3050 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003051 "build_fcn": (
3052 build_conv3d,
3053 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003054 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 TosaArgGen.agConv,
3056 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003057 "qgen": TosaQuantGen.qgConv,
3058 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003059 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3060 "error_if_validators": (
3061 TosaErrorValidator.evWrongInputType,
3062 TosaErrorValidator.evWrongOutputType,
3063 TosaErrorValidator.evWrongInputList,
3064 TosaErrorValidator.evWrongOutputList,
3065 TosaErrorValidator.evInputZeroPointNotZero,
3066 TosaErrorValidator.evWeightZeroPointNotZero,
3067 TosaErrorValidator.evPadSmallerZero,
3068 TosaErrorValidator.evStrideSmallerOne,
3069 TosaErrorValidator.evDilationSmallerOne,
3070 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003071 TosaErrorValidator.evConvOutputShapeMismatch,
3072 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003073 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003074 "template": True,
3075 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003076 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003077 "depthwise_conv2d_TEMPLATE": {
3078 "op": Op.DEPTHWISE_CONV2D,
3079 "operands": (1, 2),
3080 "filter": [1, 1],
3081 "rank": (4, 4),
3082 "build_fcn": (
3083 build_depthwise_conv2d,
3084 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003085 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003086 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003087 ),
3088 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003089 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003090 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3091 "error_if_validators": (
3092 TosaErrorValidator.evWrongInputType,
3093 TosaErrorValidator.evWrongOutputType,
3094 TosaErrorValidator.evWrongInputList,
3095 TosaErrorValidator.evWrongOutputList,
3096 TosaErrorValidator.evInputZeroPointNotZero,
3097 TosaErrorValidator.evWeightZeroPointNotZero,
3098 TosaErrorValidator.evPadSmallerZero,
3099 TosaErrorValidator.evStrideSmallerOne,
3100 TosaErrorValidator.evDilationSmallerOne,
3101 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003102 TosaErrorValidator.evConvOutputShapeMismatch,
3103 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003104 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003105 "template": True,
3106 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 "fully_connected": {
3108 "op": Op.FULLY_CONNECTED,
3109 "operands": (1, 2),
3110 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111 "build_fcn": (
3112 build_fully_connected,
3113 TosaTensorGen.tgFullyConnected,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003114 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003115 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003117 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003118 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003119 "error_if_validators": (
3120 TosaErrorValidator.evInputZeroPointNotZero,
3121 TosaErrorValidator.evWeightZeroPointNotZero,
3122 TosaErrorValidator.evWrongRank,
3123 TosaErrorValidator.evWrongInputType,
3124 TosaErrorValidator.evWrongOutputType,
3125 TosaErrorValidator.evWrongInputList,
3126 TosaErrorValidator.evWrongOutputList,
3127 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003128 "data_gen": {
3129 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3130 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "matmul": {
3133 "op": Op.MATMUL,
3134 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003135 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003136 "build_fcn": (
3137 build_matmul,
3138 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003139 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003140 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 "qgen": TosaQuantGen.qgMatmul,
3143 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003144 "error_if_validators": (
3145 TosaErrorValidator.evInputZeroPointNotZero,
3146 TosaErrorValidator.evWrongRank,
3147 TosaErrorValidator.evWrongInputType,
3148 TosaErrorValidator.evWrongOutputType,
3149 TosaErrorValidator.evWrongInputList,
3150 TosaErrorValidator.evWrongOutputList,
3151 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003152 "data_gen": {
3153 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003154 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003155 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003156 "max_pool2d": {
3157 "op": Op.MAX_POOL2D,
3158 "operands": (1, 0),
3159 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003160 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003161 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003163 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003164 TosaArgGen.agPooling,
3165 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003167 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 "error_if_validators": (
3169 TosaErrorValidator.evKernelSmallerOne,
3170 TosaErrorValidator.evStrideSmallerOne,
3171 TosaErrorValidator.evPadSmallerZero,
3172 TosaErrorValidator.evWrongRank,
3173 TosaErrorValidator.evWrongInputType,
3174 TosaErrorValidator.evWrongOutputType,
3175 TosaErrorValidator.evWrongInputList,
3176 TosaErrorValidator.evWrongOutputList,
3177 TosaErrorValidator.evPadLargerEqualKernel,
3178 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003179 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003180 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003181 "data_gen": {
3182 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3183 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003185 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003186 "transpose_conv2d_TEMPLATE": {
3187 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003188 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003189 "rank": (4, 4),
3190 "build_fcn": (
3191 build_transpose_conv2d,
3192 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003193 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003194 TosaArgGen.agTransposeConv2D,
3195 ),
3196 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003197 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003198 "invalid_test_validators": (
3199 TosaInvalidValidator.ivHeightWidthInvalid,
3200 TosaInvalidValidator.ivNonPositiveOutputShape,
3201 ),
3202 "error_if_validators": (
3203 TosaErrorValidator.evWrongInputType,
3204 TosaErrorValidator.evWrongOutputType,
3205 TosaErrorValidator.evWrongInputList,
3206 TosaErrorValidator.evWrongOutputList,
3207 TosaErrorValidator.evInputZeroPointNotZero,
3208 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003209 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003210 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003211 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003212 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003213 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003214 "template": True,
3215 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003216 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003217 "clamp": {
3218 "op": Op.CLAMP,
3219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
3221 build_clamp,
3222 TosaTensorGen.tgBasic,
3223 TosaTensorValuesGen.tvgDefault,
3224 None,
3225 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evMaxSmallerMin,
3229 TosaErrorValidator.evWrongInputType,
3230 TosaErrorValidator.evWrongOutputType,
3231 TosaErrorValidator.evWrongInputList,
3232 TosaErrorValidator.evWrongOutputList,
3233 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003234 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003235 "sigmoid": {
3236 "op": Op.SIGMOID,
3237 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 "build_fcn": (
3239 build_sigmoid,
3240 TosaTensorGen.tgBasic,
3241 TosaTensorValuesGen.tvgDefault,
3242 None,
3243 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003244 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003245 "error_if_validators": (
3246 TosaErrorValidator.evWrongInputType,
3247 TosaErrorValidator.evWrongOutputType,
3248 TosaErrorValidator.evWrongInputList,
3249 TosaErrorValidator.evWrongOutputList,
3250 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003251 },
3252 "tanh": {
3253 "op": Op.TANH,
3254 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 "build_fcn": (
3256 build_tanh,
3257 TosaTensorGen.tgBasic,
3258 TosaTensorValuesGen.tvgDefault,
3259 None,
3260 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003261 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003262 "error_if_validators": (
3263 TosaErrorValidator.evWrongInputType,
3264 TosaErrorValidator.evWrongOutputType,
3265 TosaErrorValidator.evWrongInputList,
3266 TosaErrorValidator.evWrongOutputList,
3267 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003268 },
Won Jeon78155c62023-06-10 00:20:04 +00003269 "erf": {
3270 "op": Op.ERF,
3271 "operands": (1, 0),
3272 "build_fcn": (
3273 build_erf,
3274 TosaTensorGen.tgBasic,
3275 TosaTensorValuesGen.tvgDefault,
3276 None,
3277 ),
3278 "types": TYPE_FP,
3279 "error_if_validators": (
3280 TosaErrorValidator.evWrongInputType,
3281 TosaErrorValidator.evWrongOutputType,
3282 TosaErrorValidator.evWrongInputList,
3283 TosaErrorValidator.evWrongOutputList,
3284 ),
3285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 # Elementwise Binary Operators
3287 "add": {
3288 "op": Op.ADD,
3289 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003290 "build_fcn": (
3291 build_binary_broadcast,
3292 TosaTensorGen.tgBroadcastFuzz,
3293 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003294 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003297 "error_if_validators": (
3298 TosaErrorValidator.evRankMismatch,
3299 TosaErrorValidator.evWrongInputType,
3300 TosaErrorValidator.evWrongOutputType,
3301 TosaErrorValidator.evWrongInputList,
3302 TosaErrorValidator.evWrongOutputList,
3303 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003304 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003306 "data_gen": {
3307 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3308 },
3309 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 "arithmetic_right_shift": {
3312 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3313 "operands": (2, 0),
3314 "build_fcn": (
3315 build_arithmetic_right_shift,
3316 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003317 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 TosaArgGen.agArithmeticRightShift,
3319 ),
3320 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003321 "error_if_validators": (
3322 TosaErrorValidator.evRankMismatch,
3323 TosaErrorValidator.evWrongInputType,
3324 TosaErrorValidator.evWrongOutputType,
3325 TosaErrorValidator.evWrongInputList,
3326 TosaErrorValidator.evWrongOutputList,
3327 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003328 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003329 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 "bitwise_and": {
3332 "op": Op.BITWISE_AND,
3333 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003334 "build_fcn": (
3335 build_binary_broadcast,
3336 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003337 TosaTensorValuesGen.tvgLazyGenDefault,
3338 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003339 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003340 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003341 "error_if_validators": (
3342 TosaErrorValidator.evRankMismatch,
3343 TosaErrorValidator.evWrongInputType,
3344 TosaErrorValidator.evWrongOutputType,
3345 TosaErrorValidator.evWrongInputList,
3346 TosaErrorValidator.evWrongOutputList,
3347 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003348 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 "bitwise_or": {
3352 "op": Op.BITWISE_OR,
3353 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003354 "build_fcn": (
3355 build_binary_broadcast,
3356 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003357 TosaTensorValuesGen.tvgLazyGenDefault,
3358 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003361 "error_if_validators": (
3362 TosaErrorValidator.evRankMismatch,
3363 TosaErrorValidator.evWrongInputType,
3364 TosaErrorValidator.evWrongOutputType,
3365 TosaErrorValidator.evWrongInputList,
3366 TosaErrorValidator.evWrongOutputList,
3367 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003368 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 "bitwise_xor": {
3372 "op": Op.BITWISE_XOR,
3373 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003374 "build_fcn": (
3375 build_binary_broadcast,
3376 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003377 TosaTensorValuesGen.tvgLazyGenDefault,
3378 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003381 "error_if_validators": (
3382 TosaErrorValidator.evRankMismatch,
3383 TosaErrorValidator.evWrongInputType,
3384 TosaErrorValidator.evWrongOutputType,
3385 TosaErrorValidator.evWrongInputList,
3386 TosaErrorValidator.evWrongOutputList,
3387 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003388 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003389 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003390 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003391 "intdiv": {
3392 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003393 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003394 "build_fcn": (
3395 build_binary_broadcast,
3396 TosaTensorGen.tgBroadcastFuzz,
3397 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003398 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003400 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003401 "error_if_validators": (
3402 TosaErrorValidator.evRankMismatch,
3403 TosaErrorValidator.evWrongInputType,
3404 TosaErrorValidator.evWrongOutputType,
3405 TosaErrorValidator.evWrongInputList,
3406 TosaErrorValidator.evWrongOutputList,
3407 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003408 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003409 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003410 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003411 "logical_and": {
3412 "op": Op.LOGICAL_AND,
3413 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 "build_fcn": (
3415 build_binary_broadcast,
3416 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003417 TosaTensorValuesGen.tvgLazyGenDefault,
3418 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003420 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003421 "error_if_validators": (
3422 TosaErrorValidator.evRankMismatch,
3423 TosaErrorValidator.evWrongInputType,
3424 TosaErrorValidator.evWrongOutputType,
3425 TosaErrorValidator.evWrongInputList,
3426 TosaErrorValidator.evWrongOutputList,
3427 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003428 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 "logical_left_shift": {
3432 "op": Op.LOGICAL_LEFT_SHIFT,
3433 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 "build_fcn": (
3435 build_binary_broadcast,
3436 TosaTensorGen.tgBroadcastFuzz,
3437 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003438 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003441 "error_if_validators": (
3442 TosaErrorValidator.evRankMismatch,
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003448 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003450 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003451 "logical_right_shift": {
3452 "op": Op.LOGICAL_RIGHT_SHIFT,
3453 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003454 "build_fcn": (
3455 build_binary_broadcast,
3456 TosaTensorGen.tgBroadcastFuzz,
3457 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003458 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003460 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003461 "error_if_validators": (
3462 TosaErrorValidator.evRankMismatch,
3463 TosaErrorValidator.evWrongInputType,
3464 TosaErrorValidator.evWrongOutputType,
3465 TosaErrorValidator.evWrongInputList,
3466 TosaErrorValidator.evWrongOutputList,
3467 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003468 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "logical_or": {
3472 "op": Op.LOGICAL_OR,
3473 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003474 "build_fcn": (
3475 build_binary_broadcast,
3476 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003477 TosaTensorValuesGen.tvgLazyGenDefault,
3478 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003479 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003481 "error_if_validators": (
3482 TosaErrorValidator.evRankMismatch,
3483 TosaErrorValidator.evWrongInputType,
3484 TosaErrorValidator.evWrongOutputType,
3485 TosaErrorValidator.evWrongInputList,
3486 TosaErrorValidator.evWrongOutputList,
3487 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003488 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 "logical_xor": {
3492 "op": Op.LOGICAL_XOR,
3493 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 "build_fcn": (
3495 build_binary_broadcast,
3496 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003497 TosaTensorValuesGen.tvgLazyGenDefault,
3498 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003499 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003500 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003501 "error_if_validators": (
3502 TosaErrorValidator.evRankMismatch,
3503 TosaErrorValidator.evWrongInputType,
3504 TosaErrorValidator.evWrongOutputType,
3505 TosaErrorValidator.evWrongInputList,
3506 TosaErrorValidator.evWrongOutputList,
3507 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003508 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 "maximum": {
3512 "op": Op.MAXIMUM,
3513 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 "build_fcn": (
3515 build_binary_broadcast,
3516 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003517 TosaTensorValuesGen.tvgLazyGenDefault,
3518 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003521 "error_if_validators": (
3522 TosaErrorValidator.evRankMismatch,
3523 TosaErrorValidator.evWrongInputType,
3524 TosaErrorValidator.evWrongOutputType,
3525 TosaErrorValidator.evWrongInputList,
3526 TosaErrorValidator.evWrongOutputList,
3527 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003528 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003530 "data_gen": {
3531 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3532 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003533 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003534 "minimum": {
3535 "op": Op.MINIMUM,
3536 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003537 "build_fcn": (
3538 build_binary_broadcast,
3539 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003540 TosaTensorValuesGen.tvgLazyGenDefault,
3541 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003544 "error_if_validators": (
3545 TosaErrorValidator.evRankMismatch,
3546 TosaErrorValidator.evWrongInputType,
3547 TosaErrorValidator.evWrongOutputType,
3548 TosaErrorValidator.evWrongInputList,
3549 TosaErrorValidator.evWrongOutputList,
3550 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003551 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003552 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003553 "data_gen": {
3554 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3555 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003557 "mul": {
3558 "op": Op.MUL,
3559 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003560 "build_fcn": (
3561 build_mul,
3562 TosaTensorGen.tgBroadcastFuzz,
3563 TosaTensorValuesGen.tvgMul,
3564 TosaArgGen.agMul,
3565 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003566 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003567 "error_if_validators": (
3568 TosaErrorValidator.evWrongInputType,
3569 TosaErrorValidator.evWrongOutputType,
3570 TosaErrorValidator.evWrongInputList,
3571 TosaErrorValidator.evWrongOutputList,
3572 TosaErrorValidator.evRankMismatch,
3573 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003574 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003575 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003576 "data_gen": {
3577 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3578 },
3579 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "pow": {
3582 "op": Op.POW,
3583 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584 "build_fcn": (
3585 build_binary_broadcast,
3586 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003587 TosaTensorValuesGen.tvgLazyGenDefault,
3588 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003589 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003591 "error_if_validators": (
3592 TosaErrorValidator.evRankMismatch,
3593 TosaErrorValidator.evWrongInputType,
3594 TosaErrorValidator.evWrongOutputType,
3595 TosaErrorValidator.evWrongInputList,
3596 TosaErrorValidator.evWrongOutputList,
3597 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003598 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003599 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003600 "data_gen": {
3601 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3602 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "sub": {
3605 "op": Op.SUB,
3606 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607 "build_fcn": (
3608 build_binary_broadcast,
3609 TosaTensorGen.tgBroadcastFuzz,
3610 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003611 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003613 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 "error_if_validators": (
3615 TosaErrorValidator.evRankMismatch,
3616 TosaErrorValidator.evWrongInputType,
3617 TosaErrorValidator.evWrongOutputType,
3618 TosaErrorValidator.evWrongInputList,
3619 TosaErrorValidator.evWrongOutputList,
3620 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003621 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003622 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003623 "data_gen": {
3624 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3625 },
3626 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 "table": {
3629 "op": Op.TABLE,
3630 # Use the automatic generation functions to create the input array
3631 # but create the table tensor in the build function, as it may be
3632 # a different type from the input
3633 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 "build_fcn": (
3635 build_table,
3636 TosaTensorGen.tgBasic,
3637 TosaTensorValuesGen.tvgDefault,
3638 TosaArgGen.agTable,
3639 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003640 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 "error_if_validators": (
3642 TosaErrorValidator.evWrongInputType,
3643 TosaErrorValidator.evWrongOutputType,
3644 TosaErrorValidator.evWrongInputList,
3645 TosaErrorValidator.evWrongOutputList,
3646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 # Elementwise Unary operators
3649 "abs": {
3650 "op": Op.ABS,
3651 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 "build_fcn": (
3653 build_unary,
3654 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003655 TosaTensorValuesGen.tvgLazyGenDefault,
3656 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evWrongInputType,
3661 TosaErrorValidator.evWrongOutputType,
3662 TosaErrorValidator.evWrongInputList,
3663 TosaErrorValidator.evWrongOutputList,
3664 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003665 "data_gen": {
3666 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "bitwise_not": {
3670 "op": Op.BITWISE_NOT,
3671 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_unary,
3674 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003675 TosaTensorValuesGen.tvgLazyGenDefault,
3676 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evWrongInputType,
3681 TosaErrorValidator.evWrongOutputType,
3682 TosaErrorValidator.evWrongInputList,
3683 TosaErrorValidator.evWrongOutputList,
3684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 "ceil": {
3687 "op": Op.CEIL,
3688 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003689 "build_fcn": (
3690 build_unary,
3691 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003692 TosaTensorValuesGen.tvgLazyGenDefault,
3693 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 "error_if_validators": (
3697 TosaErrorValidator.evWrongInputType,
3698 TosaErrorValidator.evWrongOutputType,
3699 TosaErrorValidator.evWrongInputList,
3700 TosaErrorValidator.evWrongOutputList,
3701 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003702 "data_gen": {
3703 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3704 },
3705 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003707 "clz": {
3708 "op": Op.CLZ,
3709 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 "build_fcn": (
3711 build_unary,
3712 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003713 TosaTensorValuesGen.tvgLazyGenDefault,
3714 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evWrongInputType,
3719 TosaErrorValidator.evWrongOutputType,
3720 TosaErrorValidator.evWrongInputList,
3721 TosaErrorValidator.evWrongOutputList,
3722 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "exp": {
3725 "op": Op.EXP,
3726 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 "build_fcn": (
3728 build_unary,
3729 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003730 TosaTensorValuesGen.tvgLazyGenDefault,
3731 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evWrongInputType,
3736 TosaErrorValidator.evWrongOutputType,
3737 TosaErrorValidator.evWrongInputList,
3738 TosaErrorValidator.evWrongOutputList,
3739 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003740 "data_gen": {
3741 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3742 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 "floor": {
3745 "op": Op.FLOOR,
3746 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 "build_fcn": (
3748 build_unary,
3749 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003750 TosaTensorValuesGen.tvgLazyGenDefault,
3751 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 "error_if_validators": (
3755 TosaErrorValidator.evWrongInputType,
3756 TosaErrorValidator.evWrongOutputType,
3757 TosaErrorValidator.evWrongInputList,
3758 TosaErrorValidator.evWrongOutputList,
3759 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003760 "data_gen": {
3761 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3762 },
3763 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 "log": {
3766 "op": Op.LOG,
3767 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003768 "build_fcn": (
3769 build_unary,
3770 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003771 TosaTensorValuesGen.tvgLazyGenDefault,
3772 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003775 "error_if_validators": (
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
3780 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "logical_not": {
3783 "op": Op.LOGICAL_NOT,
3784 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_unary,
3787 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003788 TosaTensorValuesGen.tvgLazyGenDefault,
3789 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 "negate": {
3800 "op": Op.NEGATE,
3801 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 "build_fcn": (
3803 build_unary,
3804 TosaTensorGen.tgBasic,
3805 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003806 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 "qgen": TosaQuantGen.qgUnary,
3809 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 "error_if_validators": (
3811 TosaErrorValidator.evInputZeroPointNotZero,
3812 TosaErrorValidator.evOutputZeroPointNotZero,
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003818 "data_gen": {
3819 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003822 "reciprocal": {
3823 "op": Op.RECIPROCAL,
3824 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003825 "build_fcn": (
3826 build_unary,
3827 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003828 TosaTensorValuesGen.tvgLazyGenDefault,
3829 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003830 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003831 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003832 "error_if_validators": (
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003838 "data_gen": {
3839 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3840 },
3841 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003842 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 "rsqrt": {
3844 "op": Op.RSQRT,
3845 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003846 "build_fcn": (
3847 build_unary,
3848 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003849 TosaTensorValuesGen.tvgLazyGenDefault,
3850 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003851 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003853 "error_if_validators": (
3854 TosaErrorValidator.evWrongInputType,
3855 TosaErrorValidator.evWrongOutputType,
3856 TosaErrorValidator.evWrongInputList,
3857 TosaErrorValidator.evWrongOutputList,
3858 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003859 "data_gen": {
3860 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3861 },
3862 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 # Elementwise Ternary operators
3865 "select": {
3866 "op": Op.SELECT,
3867 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003868 "build_fcn": (
3869 build_select,
3870 TosaTensorGen.tgBroadcastFuzz,
3871 TosaTensorValuesGen.tvgSelect,
3872 None,
3873 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003875 "error_if_validators": (
3876 TosaErrorValidator.evRankMismatch,
3877 TosaErrorValidator.evWrongInputType,
3878 TosaErrorValidator.evWrongOutputType,
3879 TosaErrorValidator.evWrongInputList,
3880 TosaErrorValidator.evWrongOutputList,
3881 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003882 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003883 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003885 # Comparison operators
3886 "equal": {
3887 "op": Op.EQUAL,
3888 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003889 "build_fcn": (
3890 build_comparison,
3891 TosaTensorGen.tgBroadcastFuzz,
3892 TosaTensorValuesGen.tvgEqual,
3893 None,
3894 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003895 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003896 "error_if_validators": (
3897 TosaErrorValidator.evRankMismatch,
3898 TosaErrorValidator.evWrongInputType,
3899 TosaErrorValidator.evWrongOutputType,
3900 TosaErrorValidator.evWrongInputList,
3901 TosaErrorValidator.evWrongOutputList,
3902 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003903 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 "greater_equal": {
3907 "op": Op.GREATER_EQUAL,
3908 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 "build_fcn": (
3910 build_comparison,
3911 TosaTensorGen.tgBroadcastFuzz,
3912 TosaTensorValuesGen.tvgDefault,
3913 None,
3914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 "error_if_validators": (
3917 TosaErrorValidator.evRankMismatch,
3918 TosaErrorValidator.evWrongInputType,
3919 TosaErrorValidator.evWrongOutputType,
3920 TosaErrorValidator.evWrongInputList,
3921 TosaErrorValidator.evWrongOutputList,
3922 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003923 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "greater": {
3927 "op": Op.GREATER,
3928 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 "build_fcn": (
3930 build_comparison,
3931 TosaTensorGen.tgBroadcastFuzz,
3932 TosaTensorValuesGen.tvgDefault,
3933 None,
3934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003936 "error_if_validators": (
3937 TosaErrorValidator.evRankMismatch,
3938 TosaErrorValidator.evWrongInputType,
3939 TosaErrorValidator.evWrongOutputType,
3940 TosaErrorValidator.evWrongInputList,
3941 TosaErrorValidator.evWrongOutputList,
3942 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003943 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 # Reduction operators
3947 "reduce_all": {
3948 "op": Op.REDUCE_ALL,
3949 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003950 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003951 "build_fcn": (
3952 build_reduce,
3953 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003954 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 TosaArgGen.agAxis,
3956 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003957 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003958 "error_if_validators": (
3959 TosaErrorValidator.evAxisLargerRank,
3960 TosaErrorValidator.evAxisSmallerZero,
3961 TosaErrorValidator.evShapeOfAxisNotOne,
3962 TosaErrorValidator.evWrongInputType,
3963 TosaErrorValidator.evWrongOutputType,
3964 TosaErrorValidator.evWrongRank,
3965 TosaErrorValidator.evWrongInputList,
3966 TosaErrorValidator.evWrongOutputList,
3967 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 "reduce_any": {
3970 "op": Op.REDUCE_ANY,
3971 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003972 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_reduce,
3975 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003976 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003977 TosaArgGen.agAxis,
3978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 "error_if_validators": (
3981 TosaErrorValidator.evAxisLargerRank,
3982 TosaErrorValidator.evAxisSmallerZero,
3983 TosaErrorValidator.evShapeOfAxisNotOne,
3984 TosaErrorValidator.evWrongInputType,
3985 TosaErrorValidator.evWrongOutputType,
3986 TosaErrorValidator.evWrongRank,
3987 TosaErrorValidator.evWrongInputList,
3988 TosaErrorValidator.evWrongOutputList,
3989 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "reduce_max": {
3992 "op": Op.REDUCE_MAX,
3993 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003994 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003995 "build_fcn": (
3996 build_reduce,
3997 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003998 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 TosaArgGen.agAxis,
4000 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004001 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004002 "error_if_validators": (
4003 TosaErrorValidator.evAxisLargerRank,
4004 TosaErrorValidator.evAxisSmallerZero,
4005 TosaErrorValidator.evShapeOfAxisNotOne,
4006 TosaErrorValidator.evWrongInputType,
4007 TosaErrorValidator.evWrongOutputType,
4008 TosaErrorValidator.evWrongRank,
4009 TosaErrorValidator.evWrongInputList,
4010 TosaErrorValidator.evWrongOutputList,
4011 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004012 "data_gen": {
4013 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004017 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004018 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004019 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004020 "build_fcn": (
4021 build_reduce,
4022 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004023 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004024 TosaArgGen.agAxis,
4025 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004026 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004027 "error_if_validators": (
4028 TosaErrorValidator.evAxisLargerRank,
4029 TosaErrorValidator.evAxisSmallerZero,
4030 TosaErrorValidator.evShapeOfAxisNotOne,
4031 TosaErrorValidator.evWrongInputType,
4032 TosaErrorValidator.evWrongOutputType,
4033 TosaErrorValidator.evWrongRank,
4034 TosaErrorValidator.evWrongInputList,
4035 TosaErrorValidator.evWrongOutputList,
4036 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004037 "data_gen": {
4038 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004040 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "reduce_product": {
4042 "op": Op.REDUCE_PRODUCT,
4043 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004044 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004045 "build_fcn": (
4046 build_reduce,
4047 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004048 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004049 TosaArgGen.agAxis,
4050 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004052 "error_if_validators": (
4053 TosaErrorValidator.evAxisLargerRank,
4054 TosaErrorValidator.evAxisSmallerZero,
4055 TosaErrorValidator.evShapeOfAxisNotOne,
4056 TosaErrorValidator.evWrongInputType,
4057 TosaErrorValidator.evWrongOutputType,
4058 TosaErrorValidator.evWrongRank,
4059 TosaErrorValidator.evWrongInputList,
4060 TosaErrorValidator.evWrongOutputList,
4061 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004063 "reduce_sum": {
4064 "op": Op.REDUCE_SUM,
4065 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004066 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 "build_fcn": (
4068 build_reduce,
4069 TosaTensorGen.tgBasic,
4070 TosaTensorValuesGen.tvgReduceSum,
4071 TosaArgGen.agAxis,
4072 ),
James Ward24dbc422022-10-19 12:20:31 +01004073 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004074 "error_if_validators": (
4075 TosaErrorValidator.evAxisLargerRank,
4076 TosaErrorValidator.evAxisSmallerZero,
4077 TosaErrorValidator.evShapeOfAxisNotOne,
4078 TosaErrorValidator.evWrongInputType,
4079 TosaErrorValidator.evWrongOutputType,
4080 TosaErrorValidator.evWrongRank,
4081 TosaErrorValidator.evWrongInputList,
4082 TosaErrorValidator.evWrongOutputList,
4083 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004084 "data_gen": {
4085 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004087 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004088 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004089 "concat": {
4090 "op": Op.CONCAT,
4091 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004092 "build_fcn": (
4093 build_concat,
4094 TosaTensorGen.tgConcat,
4095 TosaTensorValuesGen.tvgConcat,
4096 TosaArgGen.agAxis,
4097 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004098 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004099 "error_if_validators": (
4100 TosaErrorValidator.evAxisLargerRank,
4101 TosaErrorValidator.evAxisSmallerZero,
4102 TosaErrorValidator.evConcatInputRankMismatch,
4103 TosaErrorValidator.evConcatShapeSumMismatch,
4104 TosaErrorValidator.evConcatInputDimMismatch,
4105 TosaErrorValidator.evWrongInputType,
4106 TosaErrorValidator.evWrongOutputType,
4107 TosaErrorValidator.evWrongOutputList,
4108 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004109 },
4110 "pad": {
4111 "op": Op.PAD,
4112 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004113 "build_fcn": (
4114 build_pad,
4115 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004116 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004117 TosaArgGen.agPad,
4118 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004119 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004120 "error_if_validators": (
4121 TosaErrorValidator.evWrongInputType,
4122 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004123 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004124 TosaErrorValidator.evWrongOutputType,
4125 TosaErrorValidator.evWrongInputList,
4126 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004127 TosaErrorValidator.evRankMismatch,
4128 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004129 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004130 "data_gen": {
4131 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4132 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004133 },
Won Jeona21b2e82023-08-10 10:33:01 +00004134 "dim": {
4135 "op": Op.DIM,
4136 "operands": (1, 0),
4137 "build_fcn": (
4138 build_dim,
4139 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004140 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004141 TosaArgGen.agAxis,
4142 ),
4143 "types": TYPE_FIB,
4144 "error_if_validators": (
4145 TosaErrorValidator.evAxisLargerRank,
4146 TosaErrorValidator.evAxisSmallerZero,
4147 TosaErrorValidator.evWrongInputType,
4148 TosaErrorValidator.evWrongInputList,
4149 TosaErrorValidator.evWrongOutputList,
4150 TosaErrorValidator.evWrongRank,
4151 ),
4152 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004153 "reshape": {
4154 "op": Op.RESHAPE,
4155 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 "build_fcn": (
4157 build_reshape,
4158 TosaTensorGen.tgBasic,
4159 TosaTensorValuesGen.tvgDefault,
4160 TosaArgGen.agReshape,
4161 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004162 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004163 "error_if_validators": (
4164 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4165 TosaErrorValidator.evWrongInputType,
4166 TosaErrorValidator.evWrongOutputType,
4167 TosaErrorValidator.evWrongInputList,
4168 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004169 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4170 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004171 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004172 },
4173 "reverse": {
4174 "op": Op.REVERSE,
4175 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004176 "build_fcn": (
4177 build_reverse,
4178 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004179 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004180 TosaArgGen.agAxis,
4181 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004182 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004183 "error_if_validators": (
4184 TosaErrorValidator.evAxisSmallerZero,
4185 TosaErrorValidator.evAxisLargerRank,
4186 TosaErrorValidator.evWrongInputType,
4187 TosaErrorValidator.evWrongOutputType,
4188 TosaErrorValidator.evWrongInputList,
4189 TosaErrorValidator.evWrongOutputList,
4190 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004191 },
4192 "slice": {
4193 "op": Op.SLICE,
4194 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004195 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004196 "build_fcn": (
4197 build_slice,
4198 TosaTensorGen.tgBasic,
4199 TosaTensorValuesGen.tvgDefault,
4200 TosaArgGen.agSlice,
4201 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004202 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004203 "error_if_validators": (
4204 TosaErrorValidator.evStartSmallerZero,
4205 TosaErrorValidator.evSizeSmallerEqualZero,
4206 TosaErrorValidator.evStartSizeOutsideBounds,
4207 TosaErrorValidator.evSizeOutputShapeMismatch,
4208 TosaErrorValidator.evInputSizeStartLengthMismatch,
4209 TosaErrorValidator.evWrongRank,
4210 TosaErrorValidator.evWrongInputType,
4211 TosaErrorValidator.evWrongOutputType,
4212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004214 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004215 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004216 },
4217 "tile": {
4218 "op": Op.TILE,
4219 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004220 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004221 "build_fcn": (
4222 build_tile,
4223 TosaTensorGen.tgBasic,
4224 TosaTensorValuesGen.tvgDefault,
4225 TosaArgGen.agTile,
4226 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004227 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004228 "error_if_validators": (
4229 TosaErrorValidator.evWrongInputType,
4230 TosaErrorValidator.evWrongOutputType,
4231 TosaErrorValidator.evWrongInputList,
4232 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004233 TosaErrorValidator.evRankMismatch,
4234 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004235 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004236 },
4237 "transpose": {
4238 "op": Op.TRANSPOSE,
4239 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004240 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004241 "build_fcn": (
4242 build_transpose,
4243 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004245 TosaArgGen.agTranspose,
4246 ),
4247 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evIndexOutsideBounds,
4250 TosaErrorValidator.evIndexUsedTwice,
4251 TosaErrorValidator.evWrongInputType,
4252 TosaErrorValidator.evWrongOutputType,
4253 TosaErrorValidator.evWrongInputList,
4254 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004255 TosaErrorValidator.evWrongRank,
4256 TosaErrorValidator.evRankMismatch,
4257 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004258 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004260 # Data nodes
4261 "const": {
4262 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004263 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004264 "build_fcn": (
4265 build_const,
4266 TosaTensorGen.tgBasic,
4267 TosaTensorValuesGen.tvgDefault,
4268 None,
4269 ),
Luke Hutton65872422023-02-20 10:33:04 +00004270 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004271 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004272 "identity": {
4273 "op": Op.IDENTITY,
4274 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004275 "build_fcn": (
4276 build_unary,
4277 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004278 TosaTensorValuesGen.tvgLazyGenDefault,
4279 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004280 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004281 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004282 "data_gen": {
4283 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4284 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004285 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004286 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004287 "gather": {
4288 "op": Op.GATHER,
4289 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4290 "operands": (1, 0),
4291 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004292 "build_fcn": (
4293 build_gather,
4294 TosaTensorGen.tgBasic,
4295 TosaTensorValuesGen.tvgDefault,
4296 None,
4297 ),
James Ward24dbc422022-10-19 12:20:31 +01004298 "types": (
4299 DType.INT8,
4300 DType.INT16,
4301 DType.INT32,
4302 DType.FP16,
4303 DType.BF16,
4304 DType.FP32,
4305 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004306 "error_if_validators": (
4307 TosaErrorValidator.evWrongInputType,
4308 TosaErrorValidator.evWrongOutputType,
4309 TosaErrorValidator.evWrongInputList,
4310 TosaErrorValidator.evWrongOutputList,
4311 TosaErrorValidator.evWrongRank,
4312 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004313 },
4314 "scatter": {
4315 "op": Op.SCATTER,
4316 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004317 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004318 "operands": (2, 0),
4319 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004320 "build_fcn": (
4321 build_scatter,
4322 TosaTensorGen.tgScatter,
4323 TosaTensorValuesGen.tvgDefault,
4324 None,
4325 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004326 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004327 "error_if_validators": (
4328 TosaErrorValidator.evWrongInputType,
4329 TosaErrorValidator.evWrongOutputType,
4330 TosaErrorValidator.evWrongInputList,
4331 TosaErrorValidator.evWrongOutputList,
4332 TosaErrorValidator.evWrongRank,
4333 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004334 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004335 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004336 "resize": {
4337 "op": Op.RESIZE,
4338 "operands": (1, 0),
4339 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004340 "build_fcn": (
4341 build_resize,
4342 TosaTensorGen.tgNHWC,
4343 TosaTensorValuesGen.tvgDefault,
4344 TosaArgGen.agResize,
4345 ),
James Ward24dbc422022-10-19 12:20:31 +01004346 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004347 "invalid_test_validators": (
4348 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004349 ),
4350 "error_if_validators": (
4351 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004352 TosaErrorValidator.evScaleSmallerEqualZero,
4353 TosaErrorValidator.evScaleNLargerMax,
4354 TosaErrorValidator.evScaleDLargerMax,
4355 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004356 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004357 TosaErrorValidator.evBorderSmallerMin,
4358 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004359 TosaErrorValidator.evWrongInputType,
4360 TosaErrorValidator.evWrongOutputType,
4361 TosaErrorValidator.evWrongRank,
4362 TosaErrorValidator.evWrongInputList,
4363 TosaErrorValidator.evWrongOutputList,
4364 TosaErrorValidator.evBatchMismatch,
4365 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004366 TosaErrorValidator.evResizeOutputShapeMismatch,
4367 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004368 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004369 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004370 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004371 "cast": {
4372 "op": Op.CAST,
4373 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004374 "build_fcn": (
4375 build_cast,
4376 TosaTensorGen.tgBasic,
4377 TosaTensorValuesGen.tvgDefault,
4378 TosaArgGen.agCast,
4379 ),
James Ward8b390432022-08-12 20:48:56 +01004380 "types": (
4381 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004382 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004383 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004384 DType.INT8,
4385 DType.INT16,
4386 DType.INT32,
4387 DType.BOOL,
4388 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004389 "error_if_validators": (
4390 TosaErrorValidator.evWrongInputType,
4391 TosaErrorValidator.evWrongOutputType,
4392 TosaErrorValidator.evWrongInputList,
4393 TosaErrorValidator.evWrongOutputList,
4394 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004395 },
4396 "rescale": {
4397 "op": Op.RESCALE,
4398 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004399 "build_fcn": (
4400 build_rescale,
4401 TosaTensorGen.tgBasic,
4402 TosaTensorValuesGen.tvgDefault,
4403 TosaArgGen.agRescale,
4404 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004405 "types": [
4406 DType.UINT8,
4407 DType.INT8,
4408 DType.INT16,
4409 DType.INT32,
4410 DType.INT48,
4411 DType.UINT16,
4412 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004413 "error_if_validators": (
4414 TosaErrorValidator.evInputZeroPointNotZero,
4415 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004416 TosaErrorValidator.evU16InputZeroPointNotValid,
4417 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004418 TosaErrorValidator.evScaleTrue,
4419 TosaErrorValidator.evScaleNotTrue,
4420 TosaErrorValidator.evWrongInputType,
4421 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004422 TosaErrorValidator.evWrongInputList,
4423 TosaErrorValidator.evWrongOutputList,
4424 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004425 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004426 # Custom
4427 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004428 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004429 # Two varients of cond_if, one that generates one of two constant tensors (no
4430 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4431 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004432 "cond_if_const": {
4433 "op": Op.COND_IF,
4434 "operands": (0, 2),
4435 "build_fcn": (
4436 build_cond_if_const,
4437 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004438 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004439 TosaArgGen.agCondIf,
4440 ),
4441 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004442 "error_if_validators": (
4443 TosaErrorValidator.evOutputListThenGraphMismatch,
4444 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004445 TosaErrorValidator.evCondIfCondNotMatchingBool,
4446 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004447 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004448 },
4449 "cond_if_binary": {
4450 "op": Op.COND_IF,
4451 "operands": (2, 0),
4452 "build_fcn": (
4453 build_cond_if_binary,
4454 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004455 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004456 TosaArgGen.agCondIf,
4457 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004458 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004459 "error_if_validators": (
4460 TosaErrorValidator.evInputListThenGraphMismatch,
4461 TosaErrorValidator.evInputListElseGraphMismatch,
4462 TosaErrorValidator.evOutputListThenGraphMismatch,
4463 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004464 TosaErrorValidator.evCondIfCondNotMatchingBool,
4465 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004466 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004468 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004469 "while_loop": {
4470 "op": Op.WHILE_LOOP,
4471 "operands": (0, 1),
4472 "build_fcn": (
4473 build_while_loop,
4474 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004475 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004476 TosaArgGen.agWhileLoop,
4477 ),
4478 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004479 "error_if_validators": (
4480 TosaErrorValidator.evInputListOutputListMismatch,
4481 TosaErrorValidator.evInputListCondGraphMismatch,
4482 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4483 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4484 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004485 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004486 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 },
Luke Hutton57287132023-02-06 14:54:18 +00004488 "fft2d": {
4489 "op": Op.FFT2D,
4490 "operands": (2, 0),
4491 "rank": (3, 3),
4492 "build_fcn": (
4493 build_fft2d,
4494 TosaTensorGen.tgFFT2d,
4495 TosaTensorValuesGen.tvgDefault,
4496 TosaArgGen.agFFT2d,
4497 ),
4498 "types": [DType.FP32],
4499 "error_if_validators": (
4500 TosaErrorValidator.evWrongInputType,
4501 TosaErrorValidator.evWrongOutputType,
4502 TosaErrorValidator.evWrongInputList,
4503 TosaErrorValidator.evWrongOutputList,
4504 TosaErrorValidator.evWrongRank,
4505 TosaErrorValidator.evBatchMismatch,
4506 TosaErrorValidator.evKernelNotPowerOfTwo,
4507 TosaErrorValidator.evFFTInputShapeMismatch,
4508 TosaErrorValidator.evFFTOutputShapeMismatch,
4509 ),
4510 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004511 "rfft2d": {
4512 "op": Op.RFFT2D,
4513 "operands": (1, 0),
4514 "rank": (3, 3),
4515 "build_fcn": (
4516 build_rfft2d,
4517 TosaTensorGen.tgRFFT2d,
4518 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004519 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004520 ),
4521 "types": [DType.FP32],
4522 "error_if_validators": (
4523 TosaErrorValidator.evWrongInputType,
4524 TosaErrorValidator.evWrongOutputType,
4525 TosaErrorValidator.evWrongInputList,
4526 TosaErrorValidator.evWrongOutputList,
4527 TosaErrorValidator.evWrongRank,
4528 TosaErrorValidator.evBatchMismatch,
4529 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004530 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004531 ),
4532 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004533 }
4534
Kevin Cheng550ccc52021-03-03 11:21:43 -08004535
Eric Kunzee5e26762020-10-13 16:11:07 -07004536class OutputShaper:
4537 # Methods in this class compute the expected output shape and datatype
4538 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; no initialisation is required."""
4541
4542 # These methods return arguments that can be used for
4543 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary op with broadcasting.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the output.
        rng: random generator. One ``integers`` draw is always consumed, plus
            one ``choice`` draw for ``ErrorIf.WrongOutputType``.
        a, b: input tensors (only ``.shape`` and ``.dtype`` are read).
        error_name: optional ``ErrorIf`` case to inject. ``RankMismatch``
            skips the equal-rank assertion.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # A size-1 dim of ``a`` takes ``b``'s size, but only for non-error tests;
    # error tests keep ``a``'s dims as they are.
    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # NOTE(review): this draw happens on every call, so the RNG advances even
    # when nothing is fuzzed; assuming ``rng`` is a NumPy Generator,
    # ``integers(0, 0)`` would raise for a rank-0 ``a``. Moving it would
    # change the generated test sequence — confirm before touching.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004577
4578 @staticmethod
4579 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004580 assert len(a.shape) == len(b.shape)
4581 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004582
4583 shape = []
4584 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004585 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004586 shape.append(a.shape[i])
4587
Kevin Cheng550ccc52021-03-03 11:21:43 -08004588 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004589
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add an output tensor with the same shape as ``a``.

    The dtype matches ``a`` unless ``error_name`` is
    ``ErrorIf.WrongOutputType``. In that case ``rng.choice`` picks a
    different dtype from a fixed candidate list.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004608
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    Output dims follow ``cond``. Where ``cond`` has size 1 and no error is
    being injected, the dim becomes the largest of the three inputs' sizes.
    One ``rng.integers`` draw is always consumed:

    - ``ErrorIf.DimensionMismatch`` bumps the drawn dim by one.
    - ``ErrorIf.WrongOutputType`` picks a dtype other than ``a.dtype``.
    - ``ErrorIf.RankMismatch`` skips the equal-rank assertion.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and broadcasting)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004642
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison (BOOL by default).

    A size-1 dim of ``a`` takes ``b``'s size wherever ``b`` has that dim.
    One ``rng.integers`` draw is always consumed:

    - ``ErrorIf.DimensionMismatch`` bumps the drawn dim by one.
    - ``ErrorIf.WrongOutputType`` replaces BOOL with a random dtype from a
      fixed list.
    - ``ErrorIf.RankMismatch`` skips the equal-rank assertion.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; ``b`` may be shorter than ``a`` under RankMismatch.
    rank_b = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < rank_b) else dim
        for idx, dim in enumerate(a.shape)
    ]

    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is absent from this list, so any pick is a wrong output type.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004676
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction over ``axis``.

    Normally the reduced axis keeps its place with size 1. The axis-related
    error cases leave the shape as it is, with one exception:
    ``ErrorIf.ShapeOfAxisNotOne`` replaces a size-1 ``axis`` with a random
    size of at least 2. ``ErrorIf.WrongOutputType`` picks a dtype other than
    ``a.dtype``.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004705
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for ARGMAX over ``axis``.

    The ``axis`` dim is removed unless an axis error is injected. Error cases:

    - ``ErrorIf.ArgmaxOutputRankMismatch``: a coin flip either drops the
      leading dim (only when more than one dim remains) or appends a 1.
    - ``ErrorIf.ArgmaxOutputShapeMismatch``: adds a random positive amount
      to every dim.
    - ``ErrorIf.WrongOutputType``: picks a dtype other than INT32.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004739
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC.
    ``padding[0]``/``padding[1]`` pad the H axis and ``padding[2]``/
    ``padding[3]`` the W axis. ``strides`` and ``dilations`` are indexed
    0 for H and 1 for W.

    The OFM dtype is ``accum_dtype``. Error cases:

    - ``ErrorIf.WrongInputType``: uses INT32 instead.
    - ``ErrorIf.WrongOutputType``: picks from ``gtu.usableDTypes`` excluding
      the expected dtype (both FP16 and FP32 for an FP16 IFM).
    - ``ErrorIf.ConvOutputShapeMismatch``: grows H and/or W by random
      multiples of the stride.
    """

    def _out_size(in_size, pad0, pad1, kernel, dilation, stride):
        # Output extent of a dilated convolution, using floor division.
        return (in_size - 1 + pad0 + pad1 - (kernel - 1) * dilation) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 grows H, 2 grows W, 3 grows both. Growing by a multiple of the
        # stride avoids also hitting the non-integer output error case.
        change = rng.choice(choices)
        if change in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    # For a wrong input type, fall back to some potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput([ifm.shape[0], out_h, out_w, filter.shape[0]], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004791
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a CONV3D operator.

    Layouts: IFM is NDHWC, the filter is ODHWI and the OFM is NDHWC.

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype for valid tests.
        strides: strides for D, H and W.
        padding: six values; indices 0-1, 2-3 and 4-5 are summed into the
            D, H and W extents respectively.
        dilations: dilations for D, H and W.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    # D, H, W live at axes 1..3 of both the NDHWC input and the ODHWI filter.
    spatial = []
    for axis in range(3):
        padded_extent = (
            ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        )
        dilated_kernel = (filter.shape[axis + 1] - 1) * dilations[axis]
        spatial.append((padded_extent - dilated_kernel) // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # change 1/2/3 perturbs D/H/W alone, 4 perturbs all three
        change = rng.choice(choices)
        for axis in range(3):
            if change in (axis + 1, 4):
                # increment in multiples of stride to not hit non-integer error case
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
4853
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M).

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype for valid tests.
        strides: strides for H and W.
        padding: four values; indices 0-1 are summed into H, 2-3 into W.
        dilations: dilations for H and W.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    # H/W are input axes 1..2 and filter axes 0..1.
    spatial = []
    for axis in range(2):
        padded_extent = (
            ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        )
        dilated_kernel = (filter.shape[axis] - 1) * dilations[axis]
        spatial.append((padded_extent - dilated_kernel) // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # change 1/2 perturbs H/W alone, 3 perturbs both
        change = rng.choice(choices)
        for axis in range(2):
            if change in (axis + 1, 3):
                # increment in multiples of stride to not hit non-integer error case
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004904
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator.

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input tensor in NHWC layout.
        kernel: kernel size; kernel[0] applies to H and kernel[1] to W.
        stride: stride[0] applies to H and stride[1] to W.
        pad: four values; pad[0] + pad[1] extend H, pad[2] + pad[3] extend W.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor. Batch and channels match the input, and so does
        the dtype unless WrongOutputType is being generated.
    """
    has_bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if has_bad_params:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            out_h = out_h + rng.choice(choices) * stride[0]
        if change in (2, 3):
            out_w = out_w + rng.choice(choices) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        # Insertion order matches the original list so rng.choice is reproducible.
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004942
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    ``input`` is [N, IC], ``filter`` is [OC, IC] and the output is [N, OC].
    The output dtype is always ``accum_dtype``, which is validated (and
    invalidated for ERROR_IF) during argument generation. ``rng`` and
    ``error_name`` are not used here.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004955
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W] and the output is [N, H, W].

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used to pick an invalid dtype for ERROR_IF tests.
        a: first input tensor.
        b: second input tensor.
        accum_dtype: accumulator dtype, used as the output dtype for valid tests
            (validated in arg_gen).
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists output dtypes that are incorrect for the input dtype.
        # Tuple order is kept stable so rng.choice stays reproducible.
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype used to leave incorrect_types unbound and
            # crash with UnboundLocalError; fall back to every known dtype
            # except the expected accumulator type.
            incorrect_types = tuple(
                dtype
                for dtype in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005003
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT.

    The output starts from a copy of the first input's shape. Its ``axis``
    dimension is then grown by the same dimension of every other input. For
    the rank-mismatch and invalid-axis ERROR_IF cases that sum cannot be
    formed, so the first input's shape is used unchanged.

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        axis: concatenation axis.
        inputs: sequence of input tensors; the first sets shape and dtype.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    cannot_concat = error_name in (
        # unable to concat tensors of different ranks
        ErrorIf.ConcatInputRankMismatch,
        # unable to concat tensors along an invalid axis
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_concat:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005039
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        a: input tensor.
        padding: per-dimension pair of padding amounts.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    for dim_idx, dim_size in enumerate(a.shape):
        pad_pair = padding[dim_idx]
        output_shape[dim_idx] = pad_pair[0] + pad_pair[1] + dim_size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005070
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM.

    The output has an empty shape list and dtype ``DType.SHAPE``. For
    WrongOutputType a non-SHAPE dtype is chosen at random. ``a`` and ``axis``
    are not used to compute the result.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    # set() round-trip kept so the candidate order (and rng.choice result)
    # matches earlier versions of this generator.
    non_shape_dtypes = set(
        (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
    )
    return ser.addOutput([], rng.choice(list(non_shape_dtypes)))
5091
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE.

    The output uses a copy of the requested ``shape``. For
    TensorSizeInputOutputMismatch every dimension is increased by a random
    amount drawn with ``rng.integers(1, 10)``.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005116
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE.

    The output shape is a copy of ``size``, perturbed for the shape-related
    ERROR_IF cases. ``start`` is not used here.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(size):
            # Dims of 2 or less are only grown; larger dims may also shrink.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005150
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE.

    Each output dimension is ``a.shape[i] * multiples[i]``.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, dim in enumerate(a.shape):
        output_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005179
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For the
    IndexOutsideBounds / IndexUsedTwice ERROR_IF cases the permutation is not
    applied and the input shape is kept.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    skip_permute = error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if not skip_permute:
        for dst_axis in range(len(output_shape)):
            output_shape[dst_axis] = a.shape[perms[dst_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst_axis in range(len(output_shape)):
            output_shape[dst_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005212
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    ``values`` is rank 3 and ``indices`` rank 2 with a matching first
    dimension (asserted unless WrongRank is being generated). The output is
    [values.shape[0], indices.shape[1], values.shape[2]].

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, width, channels]

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005238
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The output reuses ``values_in.shape`` (the same object, not a copy).
    Rank and N/W/C consistency between the operands is asserted unless
    WrongRank is being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        candidates = set(
            (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        )
        out_dtype = rng.choice(list(candidates - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE.

    The output keeps the input shape. INT16 inputs produce INT32 and any other
    input dtype produces INT8. For WrongOutputType a different dtype is picked
    at random.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005287
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE.

    Output height/width are ``((in - 1) * scale_n - offset + border) // scale_d + 1``
    per axis. ``scale`` is indexed as (y_n, y_d, x_n, x_d) and
    ``offset``/``border`` as (y, x). ``mode`` and ``input_dtype`` are not used
    here.

    Args:
        serializer: serializer that registers the output tensor.
        rng: random generator, used only for ERROR_IF perturbations.
        input: input tensor; axes 1 and 2 are treated as H and W.
        mode: resize mode (unused).
        scale: four scale values as described above.
        offset: y/x offsets.
        border: y/x borders.
        input_dtype: input dtype (unused).
        output_dtype: dtype of the created output tensor.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``serializer.addOutput``.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the shape arithmetic below stays well defined
        scale_y_n, scale_y_d = max(scale_y_n, 1), max(scale_y_d, 1)
        scale_x_n, scale_x_d = max(scale_x_n, 1), max(scale_x_d, 1)

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_limit), min(ow, dim_limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def nudge(size, step):
            # Move by one denominator step so we don't hit the non-integer
            # error case; step down when stepping up would reach the limit.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            oh = nudge(oh, scale_y_d)
        if change in (2, 3):
            ow = nudge(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Last entry reuses the batch size; presumably a wrong-rank input
        # may not have an axis 3 to read from.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005366
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are not used here.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005370
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D.

    The output shape is supplied by the caller. NOTE: for
    ConvOutputShapeMismatch the caller's ``output_shape`` list is modified in
    place (axes 1 and/or 2 are grown) before being used.

    Args:
        ser: serializer that registers the output tensor via ``addOutput``.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input feature map tensor.
        output_shape: requested output shape (mutable list).
        accum_dtype: accumulator dtype, used as the output dtype for valid tests.
        error_name: optional ErrorIf case being generated.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # change 1/2 perturbs axis 1/2 alone, 3 perturbs both
        change = rng.choice(choices)
        for axis in (1, 2):
            if change in (axis, 3):
                output_shape[axis] = output_shape[axis] + rng.choice(choices)

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [accum_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005396
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both outputs are registered with the same shape list (a copy of the
    inputs' shape) and the inputs' dtype, except where an ERROR_IF case
    perturbs them.

    Args:
        serializer: serializer that registers the output tensors.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm1: first input tensor.
        ifm2: second input tensor; must share ``ifm1``'s dtype.
        error_name: optional ErrorIf case being generated.

    Returns:
        A two-element list of output tensors.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        output_shape[bad_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5427
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    The output shape matches the input except for the last axis, which
    becomes ``last // 2 + 1``. Both outputs share that shape list and the
    input dtype, except where an ERROR_IF case perturbs them.

    Args:
        serializer: serializer that registers the output tensors.
        rng: random generator, used only for ERROR_IF perturbations.
        value: input tensor.
        error_name: optional ErrorIf case being generated.

    Returns:
        A two-element list of output tensors.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(value.shape) == 3

    output_shape = list(value.shape[:-1]) + [value.shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        output_shape[bad_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]