blob: a83ead0e5af3bf00e086c2bfe3adbc84786b49c3 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner placed at the top of every generated C++ meta-data source file
# (written by TosaTestGen.serialize).  The copyright year is evaluated once,
# when this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by the test generator.  This and the
# other 8K-level limits below currently match the 8K level defined in the
# TOSA specification.
TOSA_TENSOR_MAX_RANK = 6
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test sets (indices 0..5), plus a
# minimum value used by dot product test generation (exact meaning assumed
# from the name -- confirm against its users).
TOSA_MI_DOT_PRODUCT_TEST_SETS = range(6)
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator from the parsed command-line ``args``.

    Seeds the random generator from ``args.random_seed``, builds the
    operator lists, creates the quantization helper, the desc.json schema
    validator and the data generator library, and pre-computes the random
    value range for each floating point type from
    ``args.tensor_fp_value_range``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Validates desc.json against its JSON schema
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library may be needed for compliance set-up even
    # when data generation itself happens later (lazy data generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def resolve_fp_bound(value, fp_max):
        # The strings "max" / "-max" stand for +/- the largest value of the type
        if value == "max":
            return fp_max
        if value == "-max":
            return -fp_max
        return value

    self.random_float_range = {}
    for fp_dtype in (DType.FP32, DType.FP16, DType.BF16):
        fp_max = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype]
        bounds = (resolve_fp_bound(v, fp_max) for v in args.tensor_fp_value_range)
        self.random_float_range[fp_dtype] = tuple(sorted(bounds))
78
def createSerializer(self, opName, testPath):
    """Create a new serializer writing into ``<basePath>/<opName>/<testPath>``.

    The directory is created if it does not already exist.  Constant data
    is embedded in the flatbuffer by default.  With ``args.lazy_data_gen``
    constants are written as input files instead.  Otherwise, with
    ``args.dump_consts``, they are embedded and also dumped.
    """
    self.testPath = os.path.join(opName, testPath)
    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - constants become input files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(out_dir, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the current serializer (None until createSerializer has run)."""
    current = self.ser
    return current
95
def serialize(self, testName, metaData=None):
    """Write the current test out to ``<basePath>/<testPath>``.

    Always writes ``<testName>.tosa`` (the flatbuffer) and ``desc.json``.
    A non-empty ``metaData`` dict is stored in desc.json under "meta".  Its
    "data_gen" entry is also written as C++ source when lazy data
    generation is enabled.  Its "compliance" entry, when present, is
    written as C++ source too.  desc.json is schema-validated before the
    C++ files and desc.json itself are written.

    Args:
        testName: base name for the generated test files.
        metaData: optional dict of extra test meta data.
    """
    test_dir = Path(self.basePath) / self.testPath
    tosa_name = f"{testName}.tosa"

    # TOSA flatbuffer binary
    with (test_dir / tosa_name).open("wb") as fd:
        fd.write(self.ser.serialize())

    # JSON descriptor from the serializer, plus any meta data; validate it
    # before producing the remaining outputs
    desc = json.loads(self.ser.writeJson(tosa_name))
    if metaData:
        desc["meta"] = metaData
    self.descSchemaValidator.validate_config(desc)

    def write_cpp_meta(file_name, comment_line, var_name, config):
        # Emit ``config`` as JSON inside a C++ raw string literal constant
        with (test_dir / file_name).open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(comment_line)
            fd.write(f'const char* {var_name} = R"(')
            json.dump(config, fd)
            fd.write(')";\n\n')

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation meta data as C++ data
            write_cpp_meta(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f"json_tdg_config_{test_dir.stem}",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Compliance validation meta data as C++ data
            write_cpp_meta(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f"json_tvf_config_{test_dir.stem}",
                metaData["compliance"],
            )

    # Test descriptor
    with (test_dir / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded numpy Generator.

    Without an explicit ``seed`` the generator is seeded with
    ``self.random_seed + 1``.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) range used for random values of ``dtype``.

    For FP32/FP16/BF16 the pre-computed ``self.random_float_range`` entry
    is returned unchanged, whatever ``high_inclusive`` is.  For the other
    types ``high`` is exclusive (low <= value < high) unless
    ``high_inclusive`` is True, in which case it is reduced by one
    (low <= value <= high).

    Raises:
        Exception: if ``dtype`` is not a supported type.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # (matching types, exclusive range) pairs, checked in order
    exclusive_ranges = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # SHAPE shares the INT32 range, restricting too large values
        ((DType.INT32, DType.SHAPE), (-(1 << 31), 1 << 31)),
        ((DType.INT48,), (-(1 << 47), 1 << 47)),
    )
    for dtypes, (low, high) in exclusive_ranges:
        if dtype in dtypes:
            return (low, high - 1) if high_inclusive else (low, high)
    raise Exception("Unknown dtype: {}".format(dtype))
178
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of ``shape`` for TOSA type ``dtype``.

    Values come from ``data_range`` (low, high) when supplied, otherwise
    from ``getDTypeRange(dtype)``.  BOOL ignores the range and returns
    np.bool_.  INT48 gives int64.  FP16 gives float16.  FP32 and BF16 give
    float32, with BF16 values passed through ``gtu.vect_f32_to_bf16``.
    All other types give int32.
    """
    low, high = (
        self.getDTypeRange(dtype) if data_range is None else data_range
    )

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # Every remaining integer type
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)
    samples = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        samples = np.float32(gtu.vect_f32_to_bf16(samples))
    return samples
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless lazy data generation
    is enabled, in which case None is passed.  Returns the list of values
    returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None
        if not self.args.lazy_data_gen:
            data = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
217
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached unless lazy data generation
    is enabled, in which case None is passed.  Returns the list of values
    returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    def add_one(shape, dtype):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        return self.ser.addConst(shape, dtype, data)

    return [add_one(shape, dtype) for shape, dtype in zip(shape_list, dtype_list)]
230
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    A truthy shape recorded with setTargetShape is returned as-is.
    Otherwise ``rank`` dimensions are drawn from ``self.rng`` between
    ``args.tensor_shape_range[0]`` (inclusive) and
    ``args.tensor_shape_range[1]`` (exclusive).
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(
            low=shape_range[0], high=shape_range[1], size=rank
        )
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
def setTargetShape(self, shape):
    """Record ``shape`` as the shape that makeShape should return.

    A falsy value (such as None) makes makeShape go back to random shapes.
    """
    self.targetted_shape = shape
244
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from ``self.rng`` in [low, high)."""
    draws = np.int32(self.rng.integers(low=low, high=high, size=1))
    return draws[0]
247
def getRandNumberDType(self, dtype):
    """Return a single random scalar of TOSA type ``dtype``.

    Values are drawn from ``getDTypeRange(dtype)``.  FP32 gives np.float32
    and FP16 gives np.float16.  BF16 gives the result of
    ``gtu.vect_f32_to_bf16`` on a float32 sample.  BOOL picks False or
    True.  INT48 gives np.int64 and every other type gives np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample = np.float32(sample)
        return sample if dtype == DType.FP32 else gtu.vect_f32_to_bf16(sample)
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 needs the wider integer type
    int_type = np.int64 if dtype == DType.INT48 else np.int32
    return int_type(self.rng.integers(low, high, size=1))[0]
265
def shapeStr(self, shape):
    """Return the dimensions of ``shape`` joined by "x".

    For example ``[1, 2, 3]`` gives ``"1x2x3"`` and an empty shape gives
    ``""``.  Each dimension is converted with ``str``.
    """
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
def typeStr(self, dtype):
    """Return the short string name of ``dtype``.

    A list or tuple of at least two dtypes gives the names of its first two
    entries joined by "x".  Every entry is still converted, but a third one
    (the accumulator) is left out of the result.  Single dtypes are looked
    up in ``gtu.DTYPE_ATTRIBUTES``.

    Raises:
        Exception: if a dtype has no entry in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(t) for t in dtype]
    # Limit to the first 2 types as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
def typeWidth(self, dtype):
    """Return the width recorded for ``dtype`` in ``gtu.DTYPE_ATTRIBUTES``.

    Raises:
        Exception: if ``dtype`` has no entry in ``gtu.DTYPE_ATTRIBUTES``.
    """
    attributes = gtu.DTYPE_ATTRIBUTES
    if dtype not in attributes:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return attributes[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size`` in place and return ``shape``.

    Nothing changes when ``args.max_batch_size`` is falsy or when explicit
    target shapes were requested (``args.target_shapes``).
    """
    max_batch = self.args.max_batch_size
    if not max_batch or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], max_batch)
    return shape
301
def makeDimension(self):
    """Return a random dimension size from ``randInt`` over ``args.tensor_shape_range``."""
    dim_range = self.args.tensor_shape_range
    return self.randInt(low=dim_range[0], high=dim_range[1])
306
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Create the compliance meta data for the expected output tensor.

    Returns None for ERROR_IF tests (``errorName`` set) or when the output
    or input type is not supported by compliance.  Otherwise returns a dict
    with the compliance "mode" name and the output "data_type".  For
    DOT_PRODUCT it adds "dot_product_info" and for ULP it adds "ulp_info".
    """
    if errorName:
        # No compliance for error tests
        return None
    if not (
        gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
        and gtu.dtypeIsSupportedByCompliance(inputType)
    ):
        # ...nor for unsupported types currently
        return None

    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        # Prefer "ksb" over "ks" when both could be present
        ks_key = "ksb" if "ksb" in argsDict else "ks"
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict[ks_key]),
        }
        mode = gtu.ComplianceMode.DOT_PRODUCT
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        mode = gtu.ComplianceMode.ULP
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.EXP, Op.POW):
        mode = gtu.ComplianceMode.ABS_ERROR
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
346
347 # Build Op functions
348 # Create the output tensor (calling OutputShaper as needed)
349 # Do final tweaks to attributes (if necessary for errorIf)
350 # Add Op into graph
351 # Return resulting tensor information or BuildInfo
352
class BuildInfo:
    """Return value of the newer build_* functions.

    Bundles the result tensor with its compliance meta data dict.  The dict
    may be None, for example for ERROR_IF tests.
    """

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700359
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input operator and return a TosaTestGen.BuildInfo.

    For WrongOutputType ERROR_IF tests whose output is not INT8/UINT8, the
    ``qinfo`` zero points are recomputed from the input and output types.
    NEGATE carries a NegateAttribute built from ``qinfo[0]`` and
    ``qinfo[1]``.  LOG has no compliance data yet.  Returns None if the
    ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    needs_new_qinfo = (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in (DType.INT8, DType.UINT8)
    )
    if needs_new_qinfo:
        # Ensure new output type has correct qinfo
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not is_valid:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if op["op"] in (Op.LOG,):
        # TODO - add compliance support LOG
        compliance = None
    else:
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700416
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input broadcasting operator and return a TosaTestGen.BuildInfo.

    Args:
        op: operator definition dict (uses "op" and "operands").
        inputs: the two input tensors.
        args_dict: test arguments, passed on to the compliance meta data.
        validator_fcns: ERROR_IF validator functions (optional).
        error_name: ERROR_IF test name, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        BuildInfo holding the result tensor and its compliance data, or None
        if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700458
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator whose inputs are not broadcast.

    No ERROR_IF handling is done here; ``validator_fcns`` and
    ``error_name`` are accepted but unused.  Returns the result tensor.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])
    return result_tens
463
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Build the arithmetic right shift operator ``op`` on inputs ``a`` and ``b``.

    ``round`` is passed to the ArithmeticRightShiftAttribute.  Returns the
    result tensor, or None if the ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], input_list, output_list, shift_attr)
    return result_tens
501
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a multiply operator from two inputs and return a TosaTestGen.BuildInfo.

    ``args_dict["shift"]`` feeds the MulAttribute.  Non-float inputs
    produce an INT32 result.  For WrongOutputType ERROR_IF tests the output
    type is replaced by a random pick of INT8/INT16/INT48.  Returns None if
    the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        # Special for multiply: force the result to INT32 for INT types
        result_tensor.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tensor.setDtype(wrong_dtype)

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], input_list, output_list, mul_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700557
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Build the table operator ``op`` on input ``a`` with lookup values ``table``.

    Returns the result tensor, or None if the ERROR_IF validation rejects
    the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, table_attr)
    return result_tens
591
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Build the select operator ``op`` with inputs ``cond``, ``a`` and ``b``.

    Returns the result tensor, or None if the ERROR_IF validation rejects
    the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
628
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Build the comparison operator ``op`` on inputs ``a`` and ``b``.

    Returns the result tensor, or None if the ERROR_IF validation rejects
    the test.
    """
    result_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # ERROR_IF tests may invalidate the input/output name lists
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
667
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an argmax operator and return a TosaTestGen.BuildInfo.

    Args:
        op: operator definition dict (uses "op" and "operands").
        inputs: list holding the single input tensor.
        args_dict: test arguments; "axis" is used for the AxisAttribute and
            the dict is passed on to the compliance meta data.
        validator_fcns: ERROR_IF validator functions (optional).
        error_name: ERROR_IF test name, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        BuildInfo holding the result tensor and its compliance data, or None
        if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700711
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator applied to the single input tensor.

    ``args_dict`` provides ``kernel``, ``stride`` and ``pad``. It may also
    provide ``acc_type``; when absent, DType.UNKNOWN is written as the
    accumulator type. ``qinfo`` carries the two zero points written into
    the PoolAttribute. It defaults to [0, 0] when None, and it is
    regenerated for WrongInputType tests whose input is not INT8/UINT8.

    Returns:
        The output tensor, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Regenerate the zero points from the actual input and output types
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)
    return out_tensor
780
def build_maxpool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a max pooling operator and attach compliance meta data.

    The operator itself is built by build_pool2d with the same arguments.

    Returns:
        TosaTestGen.BuildInfo(output tensor, compliance meta data), or None
        when build_pool2d rejects the test (ERROR_IF validation failed).
    """
    result_tensor = self.build_pool2d(
        op,
        inputs,
        args_dict,
        validator_fcns,
        error_name,
        qinfo,
    )
    if result_tensor is None:
        # build_pool2d returns None when evValidateErrorIfs rejects the test.
        # Propagate that None, as the other build_* methods do. Wrapping it
        # in a BuildInfo (presumably a tuple, hence always truthy - TODO
        # confirm) would make a rejected test look like a successful build,
        # and tensorComplianceMetaData would be handed a None output tensor.
        return None

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100803
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution from ``inputs = [ifm, weight, bias]``.

    ``args_dict`` provides ``acc_type``, ``stride``, ``pad`` (4 values) and
    ``dilation``. ``qinfo`` carries the two zero points written into the
    ConvAttribute; they are regenerated for WrongInputType tests whose
    input is not INT8/UINT8.

    Returns:
        TosaTestGen.BuildInfo(output tensor, compliance meta data), or None
        when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]
    assert len(padding) == 4

    out_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Regenerate the zero points from the actual input and output types
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, weight.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the attribute is always False
    local_bound = False
    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700885
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution from ``inputs = [ifm, weight, bias]``.

    ``args_dict`` provides ``acc_type``, ``stride``, ``pad`` (6 values) and
    ``dilation``. ``qinfo`` carries the two zero points written into the
    ConvAttribute; they are regenerated for WrongInputType tests whose
    input is not INT8/UINT8. No compliance meta data is produced here.

    Returns:
        The output tensor, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]
    assert len(padding) == 6

    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Regenerate the zero points from the actual input and output types
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, weight.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the attribute is always False
    local_bound = False
    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
962
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D transposed convolution of ``ifm`` with ``filter``.

    ``out_pad`` must hold 4 values. The output tensor is shaped from the
    explicit ``output_shape``. ``qinfo`` carries the two zero points written
    into the TransposeConvAttribute; they are regenerated for
    WrongInputType tests whose input is not INT8/UINT8.

    Returns:
        The output tensor, or None when evValidateErrorIfs rejects the test.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Regenerate the zero points from the actual input and output types
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the attribute is always False
    local_bound = False
    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )
    self.ser.addOperator(op["op"], in_names, out_names, tconv_attr)
    return out_tensor
1030
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise 2D convolution from ``[ifm, weight, bias]``.

    ``args_dict`` provides ``acc_type``, ``stride``, ``pad`` and
    ``dilation``. ``qinfo`` carries the two zero points written into the
    ConvAttribute; they are regenerated for WrongInputType tests whose
    input is not INT8/UINT8. No compliance meta data is produced here.

    Returns:
        The output tensor, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Regenerate the zero points from the actual input and output types
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, weight.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_count,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now the attribute is always False
    local_bound = False
    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
1106
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully connected operator from ``[ifm, weight, bias]``.

    ``args_dict["acc_type"]`` is the accumulator type. ``qinfo`` carries
    the two zero points written into the FullyConnectedAttribute.

    Returns:
        TosaTestGen.BuildInfo(output tensor, compliance meta data), or None
        when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, weight.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, fc_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001162
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a matrix multiply of ``inputs = [a, b]``.

    ``args_dict["acc_type"]`` is the accumulator type. ``qinfo`` carries
    the two zero points written into the MatMulAttribute.

    Returns:
        TosaTestGen.BuildInfo(output tensor, compliance meta data), or None
        when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    accum_dtype = args_dict["acc_type"]
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, lhs, rhs, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        input2_shape=rhs.shape,
        input2_dtype=rhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, matmul_attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001212
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize a reduction of ``inputs[0]`` along ``args_dict["axis"]``.

    Compliance meta data is produced for every reduction except
    REDUCE_PRODUCT, which has none yet.

    Returns:
        TosaTestGen.BuildInfo(output tensor, compliance meta data or None),
        or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    # TODO: Add compliance support for REDUCE_PRODUCT!
    compliance = (
        None
        if op["op"] == Op.REDUCE_PRODUCT
        else self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tensor, error_name
        )
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001261
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a clamp of ``a`` between two randomly drawn bounds.

    Two values are drawn with getRandNumberDType(a.dtype). Normally the
    smaller one becomes min and the larger one max. For
    ErrorIf.MaxSmallerMin, values are redrawn until they differ and are then
    swapped, so that max < min. BF16/FP16/FP32 inputs place the bounds in
    the last two ClampAttribute slots; all other types use the first two.

    Returns:
        The output tensor, or None when evValidateErrorIfs rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values would not produce max < min after the swap
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return out_tensor
1317
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize the op in ``op`` on ``a`` with a random FP32 LeakyRelu value.

    No ERROR_IF validation is performed; ``validator_fcns`` is unused.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1326
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize the op in ``op`` with ``a`` as its only input.

    No ERROR_IF validation is performed. Per the note above, a second
    type/input still needs to be wired up.

    Returns:
        The output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1333
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001334 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
1335 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1336
1337 # Invalidate Input/Output list for error if checks.
1338 input_list = [a.name]
1339 output_list = [result_tens.name]
1340 pCount, cCount = op["operands"]
1341 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001342 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1343 self, error_name, input_list, output_list
1344 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001345
Les Bell729b0352021-11-24 10:28:21 +00001346 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001347 self.ser,
1348 validator_fcns,
1349 error_name,
1350 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001351 input_shape=a.shape,
1352 output_shape=result_tens.shape,
1353 input_dtype=a.dtype,
1354 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001355 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001356 input_list=input_list,
1357 output_list=output_list,
1358 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001359 ):
1360 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001361
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001362 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001363 return result_tens
1364
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize the TANH operator ``op["op"]`` applied to tensor ``a``.

    Args:
        op: operator description dict; ``op["operands"]`` holds two counts
            that are summed into ``num_operands`` for ERROR_IF validation.
        a: the input tensor.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validation_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1395
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Serialize the ERF operator ``op["op"]`` applied to tensor ``a``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        a: the input tensor.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    erf_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    operand_p, operand_c = op["operands"]
    total_operands = operand_p + operand_c

    # Possibly corrupt the name lists when generating an ERROR_IF test.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=erf_out.shape,
        input_dtype=a.dtype,
        output_dtype=erf_out.dtype,
        result_tensors=[erf_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=total_operands,
    )
    if is_valid:
        self.ser.addOperator(op["op"], names_in, names_out)
        return erf_out
    return None
1426
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CONCAT operator joining ``inputs`` along ``args_dict["axis"]``.

    Args:
        op: operator description dict; ``op["op"]`` is serialized and the two
            counts in ``op["operands"]`` are summed into ``num_operands``.
        inputs: the input tensors to concatenate; ``inputs[0]`` supplies the
            shape/dtype reported to the ERROR_IF validators.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance rather than comparing type objects (pycodestyle E721).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001474
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator on the single tensor in ``inputs``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: a one-element sequence holding the tensor to pad.
        args_dict: generated arguments; ``"pad"`` (array, flattened for the
            attribute), ``"pad_const_int"`` and ``"pad_const_fp"`` are read.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.
        qinfo: forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` carrying the output tensor and the compliance
        metadata from ``tensorComplianceMetaData``, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_values = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    padded = OutputShaper.padOp(self.ser, self.rng, src, pad_values, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_values.flatten(), const_int, const_fp)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [padded.name]
    )

    validation_args = dict(
        op=op,
        input_shape=src.shape,
        output_shape=padded.shape,
        input_dtype=src.dtype,
        output_dtype=padded.dtype,
        pad=pad_values,
        qinfo=qinfo,
        result_tensors=[padded],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance_info = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, padded, error_name
    )
    return TosaTestGen.BuildInfo(padded, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07001532
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator querying ``args_dict["axis"]`` of ``inputs[0]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: a one-element sequence holding the input tensor.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    dim_axis = args_dict["axis"]
    dim_out = OutputShaper.dimOp(self.ser, self.rng, source, dim_axis, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [dim_out.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=dim_out.shape,
        output_dtype=dim_out.dtype,
        result_tensors=[dim_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(dim_out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001578
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of tensor ``a`` to ``newShape``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        a: the input tensor.
        newShape: target shape, passed to ``OutputShaper.reshapeOp`` and
            stored in the reshape attribute.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": reshaped.shape,
        "input_dtype": a.dtype,
        "output_dtype": reshaped.dtype,
        "result_tensors": [reshaped],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return reshaped
1614
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REVERSE of ``inputs[0]`` along ``args_dict["axis"]``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: a one-element sequence holding the input tensor.
        args_dict: generated arguments; only ``"axis"`` is read.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(output_tensor, None)``, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    reverse_axis = args_dict["axis"]
    tensor_out = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [tensor_out.name]
    )

    checks = dict(
        op=op,
        axis=reverse_axis,
        input_shape=tensor_in.shape,
        output_shape=tensor_out.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=tensor_out.dtype,
        result_tensors=[tensor_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reverse_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(tensor_out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001654
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of tensor ``a`` using permutation ``perms``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        a: the input tensor.
        perms: the permutation, used for output shaping, the attribute and
            ERROR_IF validation.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    transposed = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [transposed.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=transposed.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=transposed.dtype,
        result_tensors=[transposed],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return transposed
1690
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of tensor ``a`` beginning at ``start`` with ``size``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        a: the input tensor.
        start: slice start, stored in the slice attribute.
        size: slice size, stored in the slice attribute.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=sliced.shape,
        input_dtype=a.dtype,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return sliced
1729
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE of tensor ``a`` repeated by ``multiples``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        a: the input tensor.
        multiples: repeat counts, used for output shaping and the attribute.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return tiled
1764
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER from ``values`` using a freshly generated index tensor.

    A constant INT32 index tensor of shape ``[values.shape[0], W]`` is added to
    the serializer. ``W`` comes from ``self.randInt`` over
    ``self.args.tensor_shape_range``. Each index is drawn from
    ``[0, values.shape[1])`` so it never exceeds the values tensor's K dimension.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    k_dim = values.shape[1]
    shape_lo, shape_hi = (
        self.args.tensor_shape_range[0],
        self.args.tensor_shape_range[1],
    )
    w_dim = self.randInt(shape_lo, shape_hi)
    # numpy Generator.integers excludes ``high``, keeping indices in [0, K).
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [gathered.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1811
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER of ``input`` into ``values_in`` with generated indices.

    A constant INT32 index tensor of shape ``[values_in.shape[0], input.shape[1]]``
    is added to the serializer. Each index is drawn from
    ``[0, values_in.shape[1])`` so it never exceeds the K dimension of
    ``values_in``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    # numpy Generator.integers excludes ``high``, keeping indices in [0, K).
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [scattered.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return scattered
1856
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE of ``input`` with the given mode/scale/offset/border.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        input: the input tensor.
        mode, scale, offset, border: resize parameters, used for output
            shaping, the resize attribute and ERROR_IF validation.
        input_dtype, output_dtype: dtypes passed to ``OutputShaper.resizeOp``
            and the ERROR_IF validators.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    checks = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1918
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an IDENTITYN operator with inputs ``val`` and ``val2``.

    Each input gets its own output tensor shaped by ``OutputShaper.unaryOp``.
    No ERROR_IF validation is performed here, and ``validator_fcns`` is
    ignored.

    Returns:
        The output tensor for ``val``. The output for ``val2`` is added to the
        operator but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder receives ``op`` as a dict and hands ``op["op"]`` to
    # addOperator; passing the whole dict here was inconsistent. A bare op
    # code is still accepted for backward compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1926
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor with the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not used here.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001930
1931 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a CAST of tensor ``val`` to ``out_dtype``.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        val: the input tensor.
        out_dtype: target dtype passed to ``OutputShaper.typeConversionOp``.
        validator_fcns: ERROR_IF validators run by ``evValidateErrorIfs``.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # ERROR_IF tests may have their operand name lists invalidated.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [converted.name]
    )

    checks = dict(
        op=op,
        input_shape=val.shape,
        output_shape=converted.shape,
        input_dtype=val.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return converted
1964
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001965 def build_rescale(
1966 self,
1967 op,
1968 val,
1969 out_dtype,
1970 scale32,
1971 double_round,
1972 per_channel,
1973 validator_fcns,
1974 error_name,
1975 ):
1976 result_tens = OutputShaper.typeConversionOp(
1977 self.ser, self.rng, val, out_dtype, error_name
1978 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001979
1980 if per_channel:
1981 nc = val.shape[-1]
1982 else:
1983 nc = 1
1984
1985 in_type_width = self.typeWidth(val.dtype)
1986 out_type_width = self.typeWidth(out_dtype)
1987
Kevin Cheng3a478572021-01-22 17:21:02 -08001988 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001989 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001990 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001991 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001992 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001993 in_type_width += 1
1994 elif error_name in [
1995 ErrorIf.InputZeroPointNotZero,
1996 ErrorIf.U16InputZeroPointNotValid,
1997 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001998 input_zp = self.randInt(-128, 128)
1999 if input_zp == 0:
2000 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002001 in_type_width += 1
2002 elif val.dtype == DType.UINT16:
2003 # Must come after ErrorIf.U16InputZeroPointNotValid check
2004 input_zp = self.rng.choice([0, 32768])
2005 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002006 else:
2007 input_zp = 0
2008
Kevin Cheng3a478572021-01-22 17:21:02 -08002009 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002010 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002011 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002012 elif out_dtype == DType.UINT8:
2013 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002014 out_type_width += 1
2015 elif error_name in [
2016 ErrorIf.OutputZeroPointNotZero,
2017 ErrorIf.U16OutputZeroPointNotValid,
2018 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01002019 output_zp = self.randInt(-128, 128)
2020 if output_zp == 0:
2021 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002022 out_type_width += 1
2023 elif out_dtype == DType.UINT16:
2024 # Must come after ErrorIf.U16OutputZeroPointNotValid check
2025 output_zp = self.rng.choice([0, 32768])
2026 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002027 else:
2028 output_zp = 0
2029
2030 # Calculate scale based on:
2031 # scale = a *(2^output_width)/(2^input_width))
2032
2033 a = np.float32(self.rng.random(size=[nc]))
2034 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2035
2036 if scale32:
2037 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002038 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002039 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2040 else:
2041 # Cap the scaling at 2^15 - 1 for scale16
2042 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2043
Kevin Cheng550ccc52021-03-03 11:21:43 -08002044 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002045
2046 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2047 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002048 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2049 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002050
2051 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002052 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2053 scale_arr[i], scale32
2054 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002055 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2056 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002057
Kevin Cheng550ccc52021-03-03 11:21:43 -08002058 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002059 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002060 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002061 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002062 assert val.placeholderFilename
2063 values = np.load(
2064 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2065 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002066 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2067 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2068 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2069 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002070 if not np.all(np.array_equal(values, val_adj)):
2071 # Values changed so overwrite file with new values
2072 np.save(
2073 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2074 val_adj,
2075 False,
2076 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002077
Matthew Haddonc2025212021-10-08 21:21:05 +01002078 # Invalidate Input/Output list for error if checks.
2079 input_list = [val.name]
2080 output_list = [result_tens.name]
2081 pCount, cCount = op["operands"]
2082 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002083 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2084 self, error_name, input_list, output_list
2085 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002086
2087 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002088 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002089 self.ser,
2090 validator_fcns,
2091 error_name,
2092 op=op,
2093 input_dtype=val.dtype,
2094 output_dtype=out_dtype,
2095 input_shape=val.shape,
2096 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002097 scale32=scale32,
2098 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002099 input_list=input_list,
2100 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002101 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002102 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002103 ):
2104 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002105
Eric Kunzee5e26762020-10-13 16:11:07 -07002106 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002107 attr.RescaleAttribute(
2108 input_zp,
2109 output_zp,
2110 multiplier_arr,
2111 shift_arr,
2112 scale32,
2113 double_round,
2114 per_channel,
2115 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002116
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002117 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002118 return result_tens
2119
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    Normally the tensor is a BOOL constant with an empty shape (a single
    element) holding ``cond``. The COND_IF ERROR_IF tests change this:
    - CondIfCondNotMatchingBool: the dtype comes from
      ``gtu.get_wrong_output_type`` (presumably a non-BOOL type; see helper).
    - CondIfCondShapeNotSizeOne: the shape is randomly [2] or [1, 2].

    Returns the serializer tensor created by ``self.ser.addConst``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2136
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block body is a single constant filled with random values in [0, 256).
    For CondIfOutputListThenGraphMismatch / CondIfOutputListElseGraphMismatch,
    the selected block emits a constant whose shape differs from the result.

    Returns the COND_IF result tensor, or None when ERROR_IF validation
    rejects the generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    bad_then = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    bad_else = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if bad_then or bad_else:
        # Perturb every dimension: large dims may shrink or grow, small dims
        # (<= 3) only grow so the mismatched shape stays positive
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))
    else:
        bad_shape = bad_arr = None

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # COND_IF result, shaped like the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block just outputs its constant (or the mismatched one)
    for block_name, block_arr, mismatched in (
        (then_block, then_arr, bad_then),
        (else_block, else_arr, bad_else),
    ):
        self.ser.addBasicBlock(block_name)
        if mismatched:
            const_tens = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            const_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(const_tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if valid else None
2205
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a``, ``b``.

    FP32/BF16/FP16/INT32 inputs use ADD in the THEN block and SUB in the ELSE
    block; INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT / LOGICAL_LEFT_SHIFT.
    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the selected block gets an input or output whose shape is perturbed.

    Returns the COND_IF result tensor, or None when ERROR_IF validation
    rejects the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop name: reusing ``op`` here would overwrite the
    # operator-list entry that is passed to the ERROR_IF validation below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2288
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that carries (counter, a, acc) through its blocks.

    The counter starts at ``iter_val`` and the accumulator starts at zeros
    shaped like ``a``. The COND block outputs ``counter > 0``. The BODY block
    outputs ``counter - 1``, ``a`` unchanged, and ``acc + a``.

    ERROR_IF variants perturb shapes as follows:
    - InputListOutputListMismatch: the loop's accumulator output.
    - InputListCondGraphMismatch: the COND block inputs.
    - InputListBodyGraphInputMismatch: the BODY block inputs.
    - InputListBodyGraphOutputMismatch: the BODY block outputs.
    - CondGraphOutputNotMatchingBool / CondGraphOutputShapeNotSizeOne: the
      COND output dtype or shape.

    Returns the loop's accumulator output tensor, or None when ERROR_IF
    validation rejects the generated test.
    """
    shape_deltas = [-3, -2, 2, 3]

    def perturb_dims(tens):
        # Shift every dimension of a tensor copy by a random non-zero amount
        for dim_idx in range(len(tens.shape)):
            tens.shape[dim_idx] += self.rng.choice(shape_deltas)

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator input, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Outputs of the WHILE_LOOP operator itself
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        bad_acc = deepcopy(acc)
        perturb_dims(bad_acc)
        acc_out = self.ser.addIntermediate(bad_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        bad_iter = deepcopy(iter_tens)
        perturb_dims(bad_iter)
        if len(bad_iter.shape) == 0:
            # The counter is rank 0, so give it a bogus dimension instead
            bad_iter.shape.append(self.rng.choice(shape_deltas))
        bad_acc = deepcopy(acc)
        perturb_dims(bad_acc)

    # COND block: inputs (counter, a, acc) -> counter > 0
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_iter, a, bad_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)
    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block: inputs (counter, a, acc) -> (counter - 1, a, acc + a).
    # Local intermediate tensors are declared here for the block outputs.
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_iter, a, bad_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = bad_iter, bad_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if valid else None
2409
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator on the two input tensors ``val1`` and ``val2``.

    The output tensors come from ``OutputShaper.fft2dOp``. For ERROR_IF tests
    the input/output name lists may be altered before validation.
    ``local_bound`` is not exercised yet and is always serialized as False.

    Returns:
        The list of output tensors, or None when ERROR_IF validation rejects
        the generated test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    num_placeholders, num_consts = op["operands"]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tens.name for tens in results],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2460
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator on the single input tensor ``val``.

    The output tensors come from ``OutputShaper.rfft2dOp``. For ERROR_IF tests
    the input/output name lists may be altered before validation.
    ``local_bound`` is not exercised yet and is always serialized as False.

    Returns:
        The list of output tensors, or None when ERROR_IF validation rejects
        the generated test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    num_placeholders, num_consts = op["operands"]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tens.name for tens in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return results
2506
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters for generating one op's tests.

    Args:
        op: operator-list entry; ``op["rank"]`` is an inclusive (min, max)
            pair and ``op["types"]`` the supported dtypes. A dtype entry may
            itself be a list, in which case its first element is matched.
        shapeFilter: requested shapes; falsy means "no filter" ([None]).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, used only for negative tests.

    Returns:
        A dict with keys "shapeFilter", "rankFilter" and "dtypeFilter". Returns
        None for a negative test without a validator, or for any other
        ``testType``.
    """
    # Default rank range of 1-4 inclusive keeps test sizes reasonably small
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # No rank or shape request: use the smaller of the operator's range
        # and the default range
        op_ranks = range(rmin, rmax + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(rmin, rmax + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # Negative tests: the validator may pin rank/dtype/shape requirements
    reqs = validator(check=False, op=op)["param_reqs"]
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    # Without a shape requirement, keep at most two shapes to limit test count
    neg_shapes = shapeFilter[:2] if reqs["shape"] is None else reqs["shape"]
    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2587
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to restrict generation to. ``None`` (the default)
            is treated as ``[None]``, i.e. no shape filter.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: "positive", or "negative" to generate one set of tests per
            ERROR_IF validator of the op.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. An empty list is returned when ``create_filter_lists`` yields
        no filters (e.g. a negative run for an op with no ERROR_IF validators).

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Avoid a shared mutable default; [None] means "no shape filter"
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Each entry: (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2694
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Generate and serialize a single test for the operator ``opName``.

    Looks up the operator definition in ``TOSA_OP_LIST``, creates the
    serializer, generates the tensor values, calls the operator's build
    function and - if the build reports a valid test - serializes it together
    with the extra meta data destined for desc.json.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: name of the test, used for the serializer and for messages.
        dtype_or_dtypeList: one dtype applied to every operand, or a list
            giving the dtype of each operand.
        error_name: the ERROR_IF error being tested, or None.
        shapeList: list of operand shapes.
        testArgs: either a dict (new ``argsDict`` interface; must contain
            "dg_type") or a list of positional build arguments (old interface).

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
        TypeError: re-raised from the build function (old interface) after
            printing what was passed to it.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT inputs are not fixed by "operands": one dtype per input shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is only created for ops that define a "qgen" function
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands and the test
    if isinstance(testArgs, dict):
        # New interface with args info in an argsDict dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface with args info in a list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        # Only pass the optional keyword arguments that are in use. qinfo is
        # always passed by keyword (as on every other path): passing it
        # positionally after *testArgs would bind it to whichever build
        # function parameter happens to follow the op arguments.
        build_kwargs = {}
        if error_if_validators is not None:
            build_kwargs["validator_fcns"] = error_if_validators
            build_kwargs["error_name"] = error_name
        if qinfo is not None:
            build_kwargs["qinfo"] = qinfo

        try:
            result = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
        except TypeError:
            # Show what was passed to help diagnose a signature mismatch
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002820
def createDynamicOpLists(self):
    """Expand the templated convolution ops in ``TOSA_OP_LIST``.

    For every kernel size, each ``<base>_TEMPLATE`` entry is shallow-copied
    to ``<base>_<k0>x<k1>[x<k2>]``. The copy's ``"filter"`` is set to that
    kernel and its ``"template"`` is set to False. Afterwards every entry
    still marked as a template is deleted.

    The kernel sizes depend on ``self.args.level8k``. When it is set, they
    use ``self.TOSA_8K_LEVEL_MAX_KERNEL``. The method returns immediately
    if ``conv2d_TEMPLATE`` has already been removed, so calling it more than
    once is harmless.
    """
    if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if not self.args.level8k:
        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    else:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        KERNELS_2D = [[1, bigK], [bigK, 2]]
        KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]

    # (kernel sizes, template base names) - iterated so that new entries are
    # inserted kernel by kernel, in the base-name order listed here
    template_groups = (
        (KERNELS_2D, ("conv2d", "depthwise_conv2d", "transpose_conv2d")),
        (KERNELS_3D, ("conv3d",)),
    )
    for kernels, bases in template_groups:
        for k in kernels:
            suffix = "x".join(str(dim) for dim in k)
            for base in bases:
                testName = f"{base}_{suffix}"
                # Shallow copy: nested values are shared with the template
                self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                    f"{base}_TEMPLATE"
                ].copy()
                self.TOSA_OP_LIST[testName]["filter"] = k
                self.TOSA_OP_LIST[testName]["template"] = False

    # Delete any templates after having created any dynamic ops.
    # Collect the keys first: deleting from a dict while iterating over it
    # raises RuntimeError.
    templateKeys = [
        name
        for name, entry in self.TOSA_OP_LIST.items()
        if entry.get("template", False)
    ]
    for name in templateKeys:
        del self.TOSA_OP_LIST[name]
2876
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints ``TOSA_OP_LIST``. Every op must have:
    - a 2-element ``operands`` tuple;
    - a 4-element ``build_fcn`` tuple;
    - a ``types`` entry;
    - an ``op`` entry.

    An op without ``rank`` gets ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: naming the first op found with a missing or malformed
            required field.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields - unpacking checks both presence and tuple length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception("Op {} is missing the Op field in TOSA_OP_LIST".format(op))

        # Put in default rank range, if missing (mutates the entry, not the
        # dict being iterated, so this is safe inside the loop)
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002918
# Tensor operator list (TOSA_OP_LIST below), one dict per test operator:
# 'op': the TOSA Op enum value the test builds
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to a (min, max) tuple, inclusive of both;
#     if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build function, TosaTensorGen function,
#     TosaTensorValuesGen function, TosaArgGen function or None)
# 'types': list of datatypes to be tested
# 'qgen': optional TosaQuantGen function that creates the qinfo passed to
#     the build function
# 'invalid_test_validators': optional TosaInvalidValidator functions; a
#     positive test flagged by any of them is removed from the test list
# 'error_if_validators': optional TosaErrorValidator functions, passed to the
#     build function as validator_fcns
# 'data_gen': optional data generator types (gtu.DataGenType) keyed by a
#     dtype class such as "fp" - NOTE(review): consumer is not visible here
# 'compliance': optional compliance settings, e.g. {"ulp": 0.5} -
#     NOTE(review): consumer is not visible here
# 'template': True for templated ops, which createDynamicOpLists expands into
#     one op per kernel size and then deletes
# 'filter': kernel size, set on each op expanded by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

# Integer types narrower than INT32 (excluding INT4), plus floating-types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given by initOpListDefaults to ops without a 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002970
2971 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002973 "argmax": {
2974 "op": Op.ARGMAX,
2975 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002976 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002977 "build_fcn": (
2978 build_argmax,
2979 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002980 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002981 TosaArgGen.agAxis,
2982 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002983 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002984 "error_if_validators": (
2985 TosaErrorValidator.evAxisSmallerZero,
2986 TosaErrorValidator.evAxisLargerRank,
2987 TosaErrorValidator.evArgmaxOutputRankMismatch,
2988 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2989 TosaErrorValidator.evWrongRank,
2990 TosaErrorValidator.evWrongInputType,
2991 TosaErrorValidator.evWrongOutputType,
2992 TosaErrorValidator.evWrongInputList,
2993 TosaErrorValidator.evWrongOutputList,
2994 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002995 "data_gen": {
2996 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2997 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002999 "avg_pool2d": {
3000 "op": Op.AVG_POOL2D,
3001 "operands": (1, 0),
3002 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003003 "build_fcn": (
3004 build_pool2d,
3005 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003006 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003007 TosaArgGen.agPooling,
3008 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 "qgen": TosaQuantGen.qgUnary,
3010 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003011 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003012 "error_if_validators": (
3013 TosaErrorValidator.evKernelSmallerOne,
3014 TosaErrorValidator.evStrideSmallerOne,
3015 TosaErrorValidator.evPadSmallerZero,
3016 TosaErrorValidator.evWrongRank,
3017 TosaErrorValidator.evWrongInputType,
3018 TosaErrorValidator.evWrongOutputType,
3019 TosaErrorValidator.evWrongInputList,
3020 TosaErrorValidator.evWrongOutputList,
3021 TosaErrorValidator.evInputZeroPointNotZero,
3022 TosaErrorValidator.evOutputZeroPointNotZero,
3023 TosaErrorValidator.evPadLargerEqualKernel,
3024 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003025 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003028 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003029 "conv2d_TEMPLATE": {
3030 "op": Op.CONV2D,
3031 "operands": (1, 2),
3032 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003033 "build_fcn": (
3034 build_conv2d,
3035 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003036 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003037 TosaArgGen.agConv,
3038 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003040 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003041 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3042 "error_if_validators": (
3043 TosaErrorValidator.evWrongInputType,
3044 TosaErrorValidator.evWrongOutputType,
3045 TosaErrorValidator.evWrongInputList,
3046 TosaErrorValidator.evWrongOutputList,
3047 TosaErrorValidator.evInputZeroPointNotZero,
3048 TosaErrorValidator.evWeightZeroPointNotZero,
3049 TosaErrorValidator.evPadSmallerZero,
3050 TosaErrorValidator.evStrideSmallerOne,
3051 TosaErrorValidator.evDilationSmallerOne,
3052 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003053 TosaErrorValidator.evConvOutputShapeMismatch,
3054 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003055 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003056 "data_gen": {
3057 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3058 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003059 "template": True,
3060 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003061 # Templated operator. Filled in by createDynamicOpLists
3062 "conv3d_TEMPLATE": {
3063 "op": Op.CONV3D,
3064 "operands": (1, 2),
3065 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 "build_fcn": (
3067 build_conv3d,
3068 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003069 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003070 TosaArgGen.agConv,
3071 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003072 "qgen": TosaQuantGen.qgConv,
3073 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003074 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3075 "error_if_validators": (
3076 TosaErrorValidator.evWrongInputType,
3077 TosaErrorValidator.evWrongOutputType,
3078 TosaErrorValidator.evWrongInputList,
3079 TosaErrorValidator.evWrongOutputList,
3080 TosaErrorValidator.evInputZeroPointNotZero,
3081 TosaErrorValidator.evWeightZeroPointNotZero,
3082 TosaErrorValidator.evPadSmallerZero,
3083 TosaErrorValidator.evStrideSmallerOne,
3084 TosaErrorValidator.evDilationSmallerOne,
3085 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003086 TosaErrorValidator.evConvOutputShapeMismatch,
3087 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003088 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003089 "template": True,
3090 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003091 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003092 "depthwise_conv2d_TEMPLATE": {
3093 "op": Op.DEPTHWISE_CONV2D,
3094 "operands": (1, 2),
3095 "filter": [1, 1],
3096 "rank": (4, 4),
3097 "build_fcn": (
3098 build_depthwise_conv2d,
3099 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003100 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003101 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003102 ),
3103 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003104 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003105 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3106 "error_if_validators": (
3107 TosaErrorValidator.evWrongInputType,
3108 TosaErrorValidator.evWrongOutputType,
3109 TosaErrorValidator.evWrongInputList,
3110 TosaErrorValidator.evWrongOutputList,
3111 TosaErrorValidator.evInputZeroPointNotZero,
3112 TosaErrorValidator.evWeightZeroPointNotZero,
3113 TosaErrorValidator.evPadSmallerZero,
3114 TosaErrorValidator.evStrideSmallerOne,
3115 TosaErrorValidator.evDilationSmallerOne,
3116 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003117 TosaErrorValidator.evConvOutputShapeMismatch,
3118 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003119 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003120 "template": True,
3121 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003122 "fully_connected": {
3123 "op": Op.FULLY_CONNECTED,
3124 "operands": (1, 2),
3125 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003126 "build_fcn": (
3127 build_fully_connected,
3128 TosaTensorGen.tgFullyConnected,
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003129 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003130 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003131 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003133 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003134 "error_if_validators": (
3135 TosaErrorValidator.evInputZeroPointNotZero,
3136 TosaErrorValidator.evWeightZeroPointNotZero,
3137 TosaErrorValidator.evWrongRank,
3138 TosaErrorValidator.evWrongInputType,
3139 TosaErrorValidator.evWrongOutputType,
3140 TosaErrorValidator.evWrongInputList,
3141 TosaErrorValidator.evWrongOutputList,
3142 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003143 "data_gen": {
3144 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3145 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 "matmul": {
3148 "op": Op.MATMUL,
3149 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003150 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003151 "build_fcn": (
3152 build_matmul,
3153 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003154 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003155 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003156 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 "qgen": TosaQuantGen.qgMatmul,
3158 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003159 "error_if_validators": (
3160 TosaErrorValidator.evInputZeroPointNotZero,
3161 TosaErrorValidator.evWrongRank,
3162 TosaErrorValidator.evWrongInputType,
3163 TosaErrorValidator.evWrongOutputType,
3164 TosaErrorValidator.evWrongInputList,
3165 TosaErrorValidator.evWrongOutputList,
3166 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003167 "data_gen": {
3168 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003169 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003171 "max_pool2d": {
3172 "op": Op.MAX_POOL2D,
3173 "operands": (1, 0),
3174 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003175 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003176 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003178 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 TosaArgGen.agPooling,
3180 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003181 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003182 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003183 "error_if_validators": (
3184 TosaErrorValidator.evKernelSmallerOne,
3185 TosaErrorValidator.evStrideSmallerOne,
3186 TosaErrorValidator.evPadSmallerZero,
3187 TosaErrorValidator.evWrongRank,
3188 TosaErrorValidator.evWrongInputType,
3189 TosaErrorValidator.evWrongOutputType,
3190 TosaErrorValidator.evWrongInputList,
3191 TosaErrorValidator.evWrongOutputList,
3192 TosaErrorValidator.evPadLargerEqualKernel,
3193 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003194 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003196 "data_gen": {
3197 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003200 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003201 "transpose_conv2d_TEMPLATE": {
3202 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003203 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003204 "rank": (4, 4),
3205 "build_fcn": (
3206 build_transpose_conv2d,
3207 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003208 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003209 TosaArgGen.agTransposeConv2D,
3210 ),
3211 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003212 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003213 "invalid_test_validators": (
3214 TosaInvalidValidator.ivHeightWidthInvalid,
3215 TosaInvalidValidator.ivNonPositiveOutputShape,
3216 ),
3217 "error_if_validators": (
3218 TosaErrorValidator.evWrongInputType,
3219 TosaErrorValidator.evWrongOutputType,
3220 TosaErrorValidator.evWrongInputList,
3221 TosaErrorValidator.evWrongOutputList,
3222 TosaErrorValidator.evInputZeroPointNotZero,
3223 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003224 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003225 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003226 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003227 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003228 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003229 "template": True,
3230 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003231 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003232 "clamp": {
3233 "op": Op.CLAMP,
3234 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003235 "build_fcn": (
3236 build_clamp,
3237 TosaTensorGen.tgBasic,
3238 TosaTensorValuesGen.tvgDefault,
3239 None,
3240 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003241 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003242 "error_if_validators": (
3243 TosaErrorValidator.evMaxSmallerMin,
3244 TosaErrorValidator.evWrongInputType,
3245 TosaErrorValidator.evWrongOutputType,
3246 TosaErrorValidator.evWrongInputList,
3247 TosaErrorValidator.evWrongOutputList,
3248 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003249 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003250 "sigmoid": {
3251 "op": Op.SIGMOID,
3252 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003253 "build_fcn": (
3254 build_sigmoid,
3255 TosaTensorGen.tgBasic,
3256 TosaTensorValuesGen.tvgDefault,
3257 None,
3258 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003259 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003260 "error_if_validators": (
3261 TosaErrorValidator.evWrongInputType,
3262 TosaErrorValidator.evWrongOutputType,
3263 TosaErrorValidator.evWrongInputList,
3264 TosaErrorValidator.evWrongOutputList,
3265 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003266 },
3267 "tanh": {
3268 "op": Op.TANH,
3269 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003270 "build_fcn": (
3271 build_tanh,
3272 TosaTensorGen.tgBasic,
3273 TosaTensorValuesGen.tvgDefault,
3274 None,
3275 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003276 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003277 "error_if_validators": (
3278 TosaErrorValidator.evWrongInputType,
3279 TosaErrorValidator.evWrongOutputType,
3280 TosaErrorValidator.evWrongInputList,
3281 TosaErrorValidator.evWrongOutputList,
3282 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003283 },
Won Jeon78155c62023-06-10 00:20:04 +00003284 "erf": {
3285 "op": Op.ERF,
3286 "operands": (1, 0),
3287 "build_fcn": (
3288 build_erf,
3289 TosaTensorGen.tgBasic,
3290 TosaTensorValuesGen.tvgDefault,
3291 None,
3292 ),
3293 "types": TYPE_FP,
3294 "error_if_validators": (
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 ),
3300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 # Elementwise Binary Operators
3302 "add": {
3303 "op": Op.ADD,
3304 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003305 "build_fcn": (
3306 build_binary_broadcast,
3307 TosaTensorGen.tgBroadcastFuzz,
3308 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003309 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003312 "error_if_validators": (
3313 TosaErrorValidator.evRankMismatch,
3314 TosaErrorValidator.evWrongInputType,
3315 TosaErrorValidator.evWrongOutputType,
3316 TosaErrorValidator.evWrongInputList,
3317 TosaErrorValidator.evWrongOutputList,
3318 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003319 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003320 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003321 "data_gen": {
3322 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3323 },
3324 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003325 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "arithmetic_right_shift": {
3327 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3328 "operands": (2, 0),
3329 "build_fcn": (
3330 build_arithmetic_right_shift,
3331 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 TosaArgGen.agArithmeticRightShift,
3334 ),
3335 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003336 "error_if_validators": (
3337 TosaErrorValidator.evRankMismatch,
3338 TosaErrorValidator.evWrongInputType,
3339 TosaErrorValidator.evWrongOutputType,
3340 TosaErrorValidator.evWrongInputList,
3341 TosaErrorValidator.evWrongOutputList,
3342 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003343 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003345 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003346 "bitwise_and": {
3347 "op": Op.BITWISE_AND,
3348 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003349 "build_fcn": (
3350 build_binary_broadcast,
3351 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003352 TosaTensorValuesGen.tvgLazyGenDefault,
3353 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003354 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003356 "error_if_validators": (
3357 TosaErrorValidator.evRankMismatch,
3358 TosaErrorValidator.evWrongInputType,
3359 TosaErrorValidator.evWrongOutputType,
3360 TosaErrorValidator.evWrongInputList,
3361 TosaErrorValidator.evWrongOutputList,
3362 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003363 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003364 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003366 "bitwise_or": {
3367 "op": Op.BITWISE_OR,
3368 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003369 "build_fcn": (
3370 build_binary_broadcast,
3371 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003372 TosaTensorValuesGen.tvgLazyGenDefault,
3373 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003374 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003376 "error_if_validators": (
3377 TosaErrorValidator.evRankMismatch,
3378 TosaErrorValidator.evWrongInputType,
3379 TosaErrorValidator.evWrongOutputType,
3380 TosaErrorValidator.evWrongInputList,
3381 TosaErrorValidator.evWrongOutputList,
3382 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003383 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003384 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003386 "bitwise_xor": {
3387 "op": Op.BITWISE_XOR,
3388 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003389 "build_fcn": (
3390 build_binary_broadcast,
3391 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003392 TosaTensorValuesGen.tvgLazyGenDefault,
3393 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003394 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003396 "error_if_validators": (
3397 TosaErrorValidator.evRankMismatch,
3398 TosaErrorValidator.evWrongInputType,
3399 TosaErrorValidator.evWrongOutputType,
3400 TosaErrorValidator.evWrongInputList,
3401 TosaErrorValidator.evWrongOutputList,
3402 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003403 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003406 "intdiv": {
3407 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003408 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003409 "build_fcn": (
3410 build_binary_broadcast,
3411 TosaTensorGen.tgBroadcastFuzz,
3412 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003413 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003415 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 "error_if_validators": (
3417 TosaErrorValidator.evRankMismatch,
3418 TosaErrorValidator.evWrongInputType,
3419 TosaErrorValidator.evWrongOutputType,
3420 TosaErrorValidator.evWrongInputList,
3421 TosaErrorValidator.evWrongOutputList,
3422 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003423 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003424 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003425 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 "logical_and": {
3427 "op": Op.LOGICAL_AND,
3428 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 "build_fcn": (
3430 build_binary_broadcast,
3431 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003432 TosaTensorValuesGen.tvgLazyGenDefault,
3433 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 "error_if_validators": (
3437 TosaErrorValidator.evRankMismatch,
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evWrongOutputType,
3440 TosaErrorValidator.evWrongInputList,
3441 TosaErrorValidator.evWrongOutputList,
3442 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003443 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003446 "logical_left_shift": {
3447 "op": Op.LOGICAL_LEFT_SHIFT,
3448 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 "build_fcn": (
3450 build_binary_broadcast,
3451 TosaTensorGen.tgBroadcastFuzz,
3452 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003453 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003454 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003456 "error_if_validators": (
3457 TosaErrorValidator.evRankMismatch,
3458 TosaErrorValidator.evWrongInputType,
3459 TosaErrorValidator.evWrongOutputType,
3460 TosaErrorValidator.evWrongInputList,
3461 TosaErrorValidator.evWrongOutputList,
3462 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003463 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003466 "logical_right_shift": {
3467 "op": Op.LOGICAL_RIGHT_SHIFT,
3468 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003469 "build_fcn": (
3470 build_binary_broadcast,
3471 TosaTensorGen.tgBroadcastFuzz,
3472 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003473 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003474 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003476 "error_if_validators": (
3477 TosaErrorValidator.evRankMismatch,
3478 TosaErrorValidator.evWrongInputType,
3479 TosaErrorValidator.evWrongOutputType,
3480 TosaErrorValidator.evWrongInputList,
3481 TosaErrorValidator.evWrongOutputList,
3482 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003483 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003484 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 "logical_or": {
3487 "op": Op.LOGICAL_OR,
3488 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003489 "build_fcn": (
3490 build_binary_broadcast,
3491 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003492 TosaTensorValuesGen.tvgLazyGenDefault,
3493 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003495 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003496 "error_if_validators": (
3497 TosaErrorValidator.evRankMismatch,
3498 TosaErrorValidator.evWrongInputType,
3499 TosaErrorValidator.evWrongOutputType,
3500 TosaErrorValidator.evWrongInputList,
3501 TosaErrorValidator.evWrongOutputList,
3502 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003503 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003504 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 "logical_xor": {
3507 "op": Op.LOGICAL_XOR,
3508 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 "build_fcn": (
3510 build_binary_broadcast,
3511 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003512 TosaTensorValuesGen.tvgLazyGenDefault,
3513 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 "error_if_validators": (
3517 TosaErrorValidator.evRankMismatch,
3518 TosaErrorValidator.evWrongInputType,
3519 TosaErrorValidator.evWrongOutputType,
3520 TosaErrorValidator.evWrongInputList,
3521 TosaErrorValidator.evWrongOutputList,
3522 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003523 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "maximum": {
3527 "op": Op.MAXIMUM,
3528 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_binary_broadcast,
3531 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003532 TosaTensorValuesGen.tvgLazyGenDefault,
3533 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 "error_if_validators": (
3537 TosaErrorValidator.evRankMismatch,
3538 TosaErrorValidator.evWrongInputType,
3539 TosaErrorValidator.evWrongOutputType,
3540 TosaErrorValidator.evWrongInputList,
3541 TosaErrorValidator.evWrongOutputList,
3542 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003543 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003544 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003545 "data_gen": {
3546 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "minimum": {
3550 "op": Op.MINIMUM,
3551 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_binary_broadcast,
3554 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003555 TosaTensorValuesGen.tvgLazyGenDefault,
3556 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003559 "error_if_validators": (
3560 TosaErrorValidator.evRankMismatch,
3561 TosaErrorValidator.evWrongInputType,
3562 TosaErrorValidator.evWrongOutputType,
3563 TosaErrorValidator.evWrongInputList,
3564 TosaErrorValidator.evWrongOutputList,
3565 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003566 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003567 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003568 "data_gen": {
3569 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003572 "mul": {
3573 "op": Op.MUL,
3574 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 "build_fcn": (
3576 build_mul,
3577 TosaTensorGen.tgBroadcastFuzz,
3578 TosaTensorValuesGen.tvgMul,
3579 TosaArgGen.agMul,
3580 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 "error_if_validators": (
3583 TosaErrorValidator.evWrongInputType,
3584 TosaErrorValidator.evWrongOutputType,
3585 TosaErrorValidator.evWrongInputList,
3586 TosaErrorValidator.evWrongOutputList,
3587 TosaErrorValidator.evRankMismatch,
3588 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003589 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003590 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003591 "data_gen": {
3592 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3593 },
3594 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003596 "pow": {
3597 "op": Op.POW,
3598 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 "build_fcn": (
3600 build_binary_broadcast,
3601 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003602 TosaTensorValuesGen.tvgLazyGenDefault,
3603 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003604 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003606 "error_if_validators": (
3607 TosaErrorValidator.evRankMismatch,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongInputList,
3611 TosaErrorValidator.evWrongOutputList,
3612 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003613 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003615 "data_gen": {
3616 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3617 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 "sub": {
3620 "op": Op.SUB,
3621 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 "build_fcn": (
3623 build_binary_broadcast,
3624 TosaTensorGen.tgBroadcastFuzz,
3625 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003626 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003627 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 "error_if_validators": (
3630 TosaErrorValidator.evRankMismatch,
3631 TosaErrorValidator.evWrongInputType,
3632 TosaErrorValidator.evWrongOutputType,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003636 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003638 "data_gen": {
3639 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3640 },
3641 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 "table": {
3644 "op": Op.TABLE,
3645 # Use the automatic generation functions to create the input array
3646 # but create the table tensor in the build function, as it may be
3647 # a different type from the input
3648 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003649 "build_fcn": (
3650 build_table,
3651 TosaTensorGen.tgBasic,
3652 TosaTensorValuesGen.tvgDefault,
3653 TosaArgGen.agTable,
3654 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003655 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003656 "error_if_validators": (
3657 TosaErrorValidator.evWrongInputType,
3658 TosaErrorValidator.evWrongOutputType,
3659 TosaErrorValidator.evWrongInputList,
3660 TosaErrorValidator.evWrongOutputList,
3661 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003662 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003663 # Elementwise Unary operators
3664 "abs": {
3665 "op": Op.ABS,
3666 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003667 "build_fcn": (
3668 build_unary,
3669 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003670 TosaTensorValuesGen.tvgLazyGenDefault,
3671 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003673 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 "error_if_validators": (
3675 TosaErrorValidator.evWrongInputType,
3676 TosaErrorValidator.evWrongOutputType,
3677 TosaErrorValidator.evWrongInputList,
3678 TosaErrorValidator.evWrongOutputList,
3679 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003680 "data_gen": {
3681 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3682 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003683 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003684 "bitwise_not": {
3685 "op": Op.BITWISE_NOT,
3686 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003687 "build_fcn": (
3688 build_unary,
3689 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003690 TosaTensorValuesGen.tvgLazyGenDefault,
3691 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 "error_if_validators": (
3695 TosaErrorValidator.evWrongInputType,
3696 TosaErrorValidator.evWrongOutputType,
3697 TosaErrorValidator.evWrongInputList,
3698 TosaErrorValidator.evWrongOutputList,
3699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "ceil": {
3702 "op": Op.CEIL,
3703 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_unary,
3706 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003707 TosaTensorValuesGen.tvgLazyGenDefault,
3708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003717 "data_gen": {
3718 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3719 },
3720 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003722 "clz": {
3723 "op": Op.CLZ,
3724 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 "build_fcn": (
3726 build_unary,
3727 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003728 TosaTensorValuesGen.tvgLazyGenDefault,
3729 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 "error_if_validators": (
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003739 "exp": {
3740 "op": Op.EXP,
3741 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 "build_fcn": (
3743 build_unary,
3744 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003745 TosaTensorValuesGen.tvgLazyGenDefault,
3746 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 "error_if_validators": (
3750 TosaErrorValidator.evWrongInputType,
3751 TosaErrorValidator.evWrongOutputType,
3752 TosaErrorValidator.evWrongInputList,
3753 TosaErrorValidator.evWrongOutputList,
3754 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003755 "data_gen": {
3756 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3757 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "floor": {
3760 "op": Op.FLOOR,
3761 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762 "build_fcn": (
3763 build_unary,
3764 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003765 TosaTensorValuesGen.tvgLazyGenDefault,
3766 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 "error_if_validators": (
3770 TosaErrorValidator.evWrongInputType,
3771 TosaErrorValidator.evWrongOutputType,
3772 TosaErrorValidator.evWrongInputList,
3773 TosaErrorValidator.evWrongOutputList,
3774 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003775 "data_gen": {
3776 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3777 },
3778 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003779 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 "log": {
3781 "op": Op.LOG,
3782 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003783 "build_fcn": (
3784 build_unary,
3785 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003786 TosaTensorValuesGen.tvgLazyGenDefault,
3787 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003788 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003790 "error_if_validators": (
3791 TosaErrorValidator.evWrongInputType,
3792 TosaErrorValidator.evWrongOutputType,
3793 TosaErrorValidator.evWrongInputList,
3794 TosaErrorValidator.evWrongOutputList,
3795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "logical_not": {
3798 "op": Op.LOGICAL_NOT,
3799 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 "build_fcn": (
3801 build_unary,
3802 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003803 TosaTensorValuesGen.tvgLazyGenDefault,
3804 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 "error_if_validators": (
3808 TosaErrorValidator.evWrongInputType,
3809 TosaErrorValidator.evWrongOutputType,
3810 TosaErrorValidator.evWrongInputList,
3811 TosaErrorValidator.evWrongOutputList,
3812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003814 "negate": {
3815 "op": Op.NEGATE,
3816 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003817 "build_fcn": (
3818 build_unary,
3819 TosaTensorGen.tgBasic,
3820 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003821 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003823 "qgen": TosaQuantGen.qgUnary,
3824 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 "error_if_validators": (
3826 TosaErrorValidator.evInputZeroPointNotZero,
3827 TosaErrorValidator.evOutputZeroPointNotZero,
3828 TosaErrorValidator.evWrongInputType,
3829 TosaErrorValidator.evWrongOutputType,
3830 TosaErrorValidator.evWrongInputList,
3831 TosaErrorValidator.evWrongOutputList,
3832 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003833 "data_gen": {
3834 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3835 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003836 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 "reciprocal": {
3838 "op": Op.RECIPROCAL,
3839 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003840 "build_fcn": (
3841 build_unary,
3842 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003843 TosaTensorValuesGen.tvgLazyGenDefault,
3844 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003845 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003846 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 "error_if_validators": (
3848 TosaErrorValidator.evWrongInputType,
3849 TosaErrorValidator.evWrongOutputType,
3850 TosaErrorValidator.evWrongInputList,
3851 TosaErrorValidator.evWrongOutputList,
3852 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003853 "data_gen": {
3854 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3855 },
3856 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 "rsqrt": {
3859 "op": Op.RSQRT,
3860 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 "build_fcn": (
3862 build_unary,
3863 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003864 TosaTensorValuesGen.tvgLazyGenDefault,
3865 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003866 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003868 "error_if_validators": (
3869 TosaErrorValidator.evWrongInputType,
3870 TosaErrorValidator.evWrongOutputType,
3871 TosaErrorValidator.evWrongInputList,
3872 TosaErrorValidator.evWrongOutputList,
3873 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003874 "data_gen": {
3875 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3876 },
3877 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 # Elementwise Ternary operators
3880 "select": {
3881 "op": Op.SELECT,
3882 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003883 "build_fcn": (
3884 build_select,
3885 TosaTensorGen.tgBroadcastFuzz,
3886 TosaTensorValuesGen.tvgSelect,
3887 None,
3888 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003889 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003890 "error_if_validators": (
3891 TosaErrorValidator.evRankMismatch,
3892 TosaErrorValidator.evWrongInputType,
3893 TosaErrorValidator.evWrongOutputType,
3894 TosaErrorValidator.evWrongInputList,
3895 TosaErrorValidator.evWrongOutputList,
3896 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003897 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003898 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 # Comparison operators
3901 "equal": {
3902 "op": Op.EQUAL,
3903 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003904 "build_fcn": (
3905 build_comparison,
3906 TosaTensorGen.tgBroadcastFuzz,
3907 TosaTensorValuesGen.tvgEqual,
3908 None,
3909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003911 "error_if_validators": (
3912 TosaErrorValidator.evRankMismatch,
3913 TosaErrorValidator.evWrongInputType,
3914 TosaErrorValidator.evWrongOutputType,
3915 TosaErrorValidator.evWrongInputList,
3916 TosaErrorValidator.evWrongOutputList,
3917 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003918 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "greater_equal": {
3922 "op": Op.GREATER_EQUAL,
3923 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_comparison,
3926 TosaTensorGen.tgBroadcastFuzz,
3927 TosaTensorValuesGen.tvgDefault,
3928 None,
3929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evRankMismatch,
3933 TosaErrorValidator.evWrongInputType,
3934 TosaErrorValidator.evWrongOutputType,
3935 TosaErrorValidator.evWrongInputList,
3936 TosaErrorValidator.evWrongOutputList,
3937 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003938 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003940 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003941 "greater": {
3942 "op": Op.GREATER,
3943 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003944 "build_fcn": (
3945 build_comparison,
3946 TosaTensorGen.tgBroadcastFuzz,
3947 TosaTensorValuesGen.tvgDefault,
3948 None,
3949 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003950 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 "error_if_validators": (
3952 TosaErrorValidator.evRankMismatch,
3953 TosaErrorValidator.evWrongInputType,
3954 TosaErrorValidator.evWrongOutputType,
3955 TosaErrorValidator.evWrongInputList,
3956 TosaErrorValidator.evWrongOutputList,
3957 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003958 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003960 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 # Reduction operators
3962 "reduce_all": {
3963 "op": Op.REDUCE_ALL,
3964 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003965 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 "build_fcn": (
3967 build_reduce,
3968 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003969 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 TosaArgGen.agAxis,
3971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003972 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003973 "error_if_validators": (
3974 TosaErrorValidator.evAxisLargerRank,
3975 TosaErrorValidator.evAxisSmallerZero,
3976 TosaErrorValidator.evShapeOfAxisNotOne,
3977 TosaErrorValidator.evWrongInputType,
3978 TosaErrorValidator.evWrongOutputType,
3979 TosaErrorValidator.evWrongRank,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "reduce_any": {
3985 "op": Op.REDUCE_ANY,
3986 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003987 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003988 "build_fcn": (
3989 build_reduce,
3990 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003991 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 TosaArgGen.agAxis,
3993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 "error_if_validators": (
3996 TosaErrorValidator.evAxisLargerRank,
3997 TosaErrorValidator.evAxisSmallerZero,
3998 TosaErrorValidator.evShapeOfAxisNotOne,
3999 TosaErrorValidator.evWrongInputType,
4000 TosaErrorValidator.evWrongOutputType,
4001 TosaErrorValidator.evWrongRank,
4002 TosaErrorValidator.evWrongInputList,
4003 TosaErrorValidator.evWrongOutputList,
4004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 "reduce_max": {
4007 "op": Op.REDUCE_MAX,
4008 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004009 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004010 "build_fcn": (
4011 build_reduce,
4012 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004013 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 TosaArgGen.agAxis,
4015 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 "error_if_validators": (
4018 TosaErrorValidator.evAxisLargerRank,
4019 TosaErrorValidator.evAxisSmallerZero,
4020 TosaErrorValidator.evShapeOfAxisNotOne,
4021 TosaErrorValidator.evWrongInputType,
4022 TosaErrorValidator.evWrongOutputType,
4023 TosaErrorValidator.evWrongRank,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004027 "data_gen": {
4028 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004032 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004034 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004035 "build_fcn": (
4036 build_reduce,
4037 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004038 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 TosaArgGen.agAxis,
4040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004042 "error_if_validators": (
4043 TosaErrorValidator.evAxisLargerRank,
4044 TosaErrorValidator.evAxisSmallerZero,
4045 TosaErrorValidator.evShapeOfAxisNotOne,
4046 TosaErrorValidator.evWrongInputType,
4047 TosaErrorValidator.evWrongOutputType,
4048 TosaErrorValidator.evWrongRank,
4049 TosaErrorValidator.evWrongInputList,
4050 TosaErrorValidator.evWrongOutputList,
4051 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004052 "data_gen": {
4053 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4054 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 "reduce_product": {
4057 "op": Op.REDUCE_PRODUCT,
4058 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004059 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004060 "build_fcn": (
4061 build_reduce,
4062 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004063 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 TosaArgGen.agAxis,
4065 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004067 "error_if_validators": (
4068 TosaErrorValidator.evAxisLargerRank,
4069 TosaErrorValidator.evAxisSmallerZero,
4070 TosaErrorValidator.evShapeOfAxisNotOne,
4071 TosaErrorValidator.evWrongInputType,
4072 TosaErrorValidator.evWrongOutputType,
4073 TosaErrorValidator.evWrongRank,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004078 "reduce_sum": {
4079 "op": Op.REDUCE_SUM,
4080 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004081 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 "build_fcn": (
4083 build_reduce,
4084 TosaTensorGen.tgBasic,
4085 TosaTensorValuesGen.tvgReduceSum,
4086 TosaArgGen.agAxis,
4087 ),
James Ward24dbc422022-10-19 12:20:31 +01004088 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004089 "error_if_validators": (
4090 TosaErrorValidator.evAxisLargerRank,
4091 TosaErrorValidator.evAxisSmallerZero,
4092 TosaErrorValidator.evShapeOfAxisNotOne,
4093 TosaErrorValidator.evWrongInputType,
4094 TosaErrorValidator.evWrongOutputType,
4095 TosaErrorValidator.evWrongRank,
4096 TosaErrorValidator.evWrongInputList,
4097 TosaErrorValidator.evWrongOutputList,
4098 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004099 "data_gen": {
4100 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4101 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004103 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004104 "concat": {
4105 "op": Op.CONCAT,
4106 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 "build_fcn": (
4108 build_concat,
4109 TosaTensorGen.tgConcat,
4110 TosaTensorValuesGen.tvgConcat,
4111 TosaArgGen.agAxis,
4112 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004113 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004114 "error_if_validators": (
4115 TosaErrorValidator.evAxisLargerRank,
4116 TosaErrorValidator.evAxisSmallerZero,
4117 TosaErrorValidator.evConcatInputRankMismatch,
4118 TosaErrorValidator.evConcatShapeSumMismatch,
4119 TosaErrorValidator.evConcatInputDimMismatch,
4120 TosaErrorValidator.evWrongInputType,
4121 TosaErrorValidator.evWrongOutputType,
4122 TosaErrorValidator.evWrongOutputList,
4123 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004124 },
4125 "pad": {
4126 "op": Op.PAD,
4127 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 "build_fcn": (
4129 build_pad,
4130 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004131 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004132 TosaArgGen.agPad,
4133 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004134 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004135 "error_if_validators": (
4136 TosaErrorValidator.evWrongInputType,
4137 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004138 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004139 TosaErrorValidator.evWrongOutputType,
4140 TosaErrorValidator.evWrongInputList,
4141 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004142 TosaErrorValidator.evRankMismatch,
4143 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004144 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004145 "data_gen": {
4146 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4147 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004148 },
Won Jeona21b2e82023-08-10 10:33:01 +00004149 "dim": {
4150 "op": Op.DIM,
4151 "operands": (1, 0),
4152 "build_fcn": (
4153 build_dim,
4154 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004155 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004156 TosaArgGen.agAxis,
4157 ),
4158 "types": TYPE_FIB,
4159 "error_if_validators": (
4160 TosaErrorValidator.evAxisLargerRank,
4161 TosaErrorValidator.evAxisSmallerZero,
4162 TosaErrorValidator.evWrongInputType,
4163 TosaErrorValidator.evWrongInputList,
4164 TosaErrorValidator.evWrongOutputList,
4165 TosaErrorValidator.evWrongRank,
4166 ),
4167 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004168 "reshape": {
4169 "op": Op.RESHAPE,
4170 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004171 "build_fcn": (
4172 build_reshape,
4173 TosaTensorGen.tgBasic,
4174 TosaTensorValuesGen.tvgDefault,
4175 TosaArgGen.agReshape,
4176 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004177 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004178 "error_if_validators": (
4179 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongInputList,
4183 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004184 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4185 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004186 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004187 },
4188 "reverse": {
4189 "op": Op.REVERSE,
4190 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004191 "build_fcn": (
4192 build_reverse,
4193 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004194 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004195 TosaArgGen.agAxis,
4196 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004197 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004198 "error_if_validators": (
4199 TosaErrorValidator.evAxisSmallerZero,
4200 TosaErrorValidator.evAxisLargerRank,
4201 TosaErrorValidator.evWrongInputType,
4202 TosaErrorValidator.evWrongOutputType,
4203 TosaErrorValidator.evWrongInputList,
4204 TosaErrorValidator.evWrongOutputList,
4205 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004206 },
4207 "slice": {
4208 "op": Op.SLICE,
4209 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004210 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004211 "build_fcn": (
4212 build_slice,
4213 TosaTensorGen.tgBasic,
4214 TosaTensorValuesGen.tvgDefault,
4215 TosaArgGen.agSlice,
4216 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004217 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004218 "error_if_validators": (
4219 TosaErrorValidator.evStartSmallerZero,
4220 TosaErrorValidator.evSizeSmallerEqualZero,
4221 TosaErrorValidator.evStartSizeOutsideBounds,
4222 TosaErrorValidator.evSizeOutputShapeMismatch,
4223 TosaErrorValidator.evInputSizeStartLengthMismatch,
4224 TosaErrorValidator.evWrongRank,
4225 TosaErrorValidator.evWrongInputType,
4226 TosaErrorValidator.evWrongOutputType,
4227 TosaErrorValidator.evWrongInputList,
4228 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004229 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004230 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004231 },
4232 "tile": {
4233 "op": Op.TILE,
4234 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004235 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004236 "build_fcn": (
4237 build_tile,
4238 TosaTensorGen.tgBasic,
4239 TosaTensorValuesGen.tvgDefault,
4240 TosaArgGen.agTile,
4241 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004242 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004243 "error_if_validators": (
4244 TosaErrorValidator.evWrongInputType,
4245 TosaErrorValidator.evWrongOutputType,
4246 TosaErrorValidator.evWrongInputList,
4247 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004248 TosaErrorValidator.evRankMismatch,
4249 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004250 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 },
4252 "transpose": {
4253 "op": Op.TRANSPOSE,
4254 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004255 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004256 "build_fcn": (
4257 build_transpose,
4258 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004259 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 TosaArgGen.agTranspose,
4261 ),
4262 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004263 "error_if_validators": (
4264 TosaErrorValidator.evIndexOutsideBounds,
4265 TosaErrorValidator.evIndexUsedTwice,
4266 TosaErrorValidator.evWrongInputType,
4267 TosaErrorValidator.evWrongOutputType,
4268 TosaErrorValidator.evWrongInputList,
4269 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004270 TosaErrorValidator.evWrongRank,
4271 TosaErrorValidator.evRankMismatch,
4272 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004273 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004275 # Data nodes
4276 "const": {
4277 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004278 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004279 "build_fcn": (
4280 build_const,
4281 TosaTensorGen.tgBasic,
4282 TosaTensorValuesGen.tvgDefault,
4283 None,
4284 ),
Luke Hutton65872422023-02-20 10:33:04 +00004285 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004287 "identity": {
4288 "op": Op.IDENTITY,
4289 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004290 "build_fcn": (
4291 build_unary,
4292 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004293 TosaTensorValuesGen.tvgLazyGenDefault,
4294 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004296 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004297 "data_gen": {
4298 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004301 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004302 "gather": {
4303 "op": Op.GATHER,
4304 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4305 "operands": (1, 0),
4306 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004307 "build_fcn": (
4308 build_gather,
4309 TosaTensorGen.tgBasic,
4310 TosaTensorValuesGen.tvgDefault,
4311 None,
4312 ),
James Ward24dbc422022-10-19 12:20:31 +01004313 "types": (
4314 DType.INT8,
4315 DType.INT16,
4316 DType.INT32,
4317 DType.FP16,
4318 DType.BF16,
4319 DType.FP32,
4320 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004321 "error_if_validators": (
4322 TosaErrorValidator.evWrongInputType,
4323 TosaErrorValidator.evWrongOutputType,
4324 TosaErrorValidator.evWrongInputList,
4325 TosaErrorValidator.evWrongOutputList,
4326 TosaErrorValidator.evWrongRank,
4327 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004328 },
4329 "scatter": {
4330 "op": Op.SCATTER,
4331 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004332 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004333 "operands": (2, 0),
4334 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004335 "build_fcn": (
4336 build_scatter,
4337 TosaTensorGen.tgScatter,
4338 TosaTensorValuesGen.tvgDefault,
4339 None,
4340 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004341 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004342 "error_if_validators": (
4343 TosaErrorValidator.evWrongInputType,
4344 TosaErrorValidator.evWrongOutputType,
4345 TosaErrorValidator.evWrongInputList,
4346 TosaErrorValidator.evWrongOutputList,
4347 TosaErrorValidator.evWrongRank,
4348 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004349 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004350 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004351 "resize": {
4352 "op": Op.RESIZE,
4353 "operands": (1, 0),
4354 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004355 "build_fcn": (
4356 build_resize,
4357 TosaTensorGen.tgNHWC,
4358 TosaTensorValuesGen.tvgDefault,
4359 TosaArgGen.agResize,
4360 ),
James Ward24dbc422022-10-19 12:20:31 +01004361 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004362 "invalid_test_validators": (
4363 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004364 ),
4365 "error_if_validators": (
4366 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004367 TosaErrorValidator.evScaleSmallerEqualZero,
4368 TosaErrorValidator.evScaleNLargerMax,
4369 TosaErrorValidator.evScaleDLargerMax,
4370 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004371 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004372 TosaErrorValidator.evBorderSmallerMin,
4373 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 TosaErrorValidator.evWrongInputType,
4375 TosaErrorValidator.evWrongOutputType,
4376 TosaErrorValidator.evWrongRank,
4377 TosaErrorValidator.evWrongInputList,
4378 TosaErrorValidator.evWrongOutputList,
4379 TosaErrorValidator.evBatchMismatch,
4380 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004381 TosaErrorValidator.evResizeOutputShapeMismatch,
4382 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004383 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004384 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004385 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004386 "cast": {
4387 "op": Op.CAST,
4388 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004389 "build_fcn": (
4390 build_cast,
4391 TosaTensorGen.tgBasic,
4392 TosaTensorValuesGen.tvgDefault,
4393 TosaArgGen.agCast,
4394 ),
James Ward8b390432022-08-12 20:48:56 +01004395 "types": (
4396 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004397 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004398 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004399 DType.INT8,
4400 DType.INT16,
4401 DType.INT32,
4402 DType.BOOL,
4403 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004404 "error_if_validators": (
4405 TosaErrorValidator.evWrongInputType,
4406 TosaErrorValidator.evWrongOutputType,
4407 TosaErrorValidator.evWrongInputList,
4408 TosaErrorValidator.evWrongOutputList,
4409 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004410 },
4411 "rescale": {
4412 "op": Op.RESCALE,
4413 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004414 "build_fcn": (
4415 build_rescale,
4416 TosaTensorGen.tgBasic,
4417 TosaTensorValuesGen.tvgDefault,
4418 TosaArgGen.agRescale,
4419 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004420 "types": [
4421 DType.UINT8,
4422 DType.INT8,
4423 DType.INT16,
4424 DType.INT32,
4425 DType.INT48,
4426 DType.UINT16,
4427 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004428 "error_if_validators": (
4429 TosaErrorValidator.evInputZeroPointNotZero,
4430 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004431 TosaErrorValidator.evU16InputZeroPointNotValid,
4432 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004433 TosaErrorValidator.evScaleTrue,
4434 TosaErrorValidator.evScaleNotTrue,
4435 TosaErrorValidator.evWrongInputType,
4436 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004437 TosaErrorValidator.evWrongInputList,
4438 TosaErrorValidator.evWrongOutputList,
4439 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004441 # Custom
4442 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004444 # Two variants of cond_if, one that generates one of two constant tensors (no
4445 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4446 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 "cond_if_const": {
4448 "op": Op.COND_IF,
4449 "operands": (0, 2),
4450 "build_fcn": (
4451 build_cond_if_const,
4452 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004453 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004454 TosaArgGen.agCondIf,
4455 ),
4456 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004457 "error_if_validators": (
4458 TosaErrorValidator.evOutputListThenGraphMismatch,
4459 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004460 TosaErrorValidator.evCondIfCondNotMatchingBool,
4461 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004462 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004463 },
4464 "cond_if_binary": {
4465 "op": Op.COND_IF,
4466 "operands": (2, 0),
4467 "build_fcn": (
4468 build_cond_if_binary,
4469 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004470 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004471 TosaArgGen.agCondIf,
4472 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004473 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004474 "error_if_validators": (
4475 TosaErrorValidator.evInputListThenGraphMismatch,
4476 TosaErrorValidator.evInputListElseGraphMismatch,
4477 TosaErrorValidator.evOutputListThenGraphMismatch,
4478 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004479 TosaErrorValidator.evCondIfCondNotMatchingBool,
4480 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004481 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004482 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004483 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004484 "while_loop": {
4485 "op": Op.WHILE_LOOP,
4486 "operands": (0, 1),
4487 "build_fcn": (
4488 build_while_loop,
4489 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004490 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004491 TosaArgGen.agWhileLoop,
4492 ),
4493 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004494 "error_if_validators": (
4495 TosaErrorValidator.evInputListOutputListMismatch,
4496 TosaErrorValidator.evInputListCondGraphMismatch,
4497 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4498 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4499 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004500 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004501 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004502 },
Luke Hutton57287132023-02-06 14:54:18 +00004503 "fft2d": {
4504 "op": Op.FFT2D,
4505 "operands": (2, 0),
4506 "rank": (3, 3),
4507 "build_fcn": (
4508 build_fft2d,
4509 TosaTensorGen.tgFFT2d,
4510 TosaTensorValuesGen.tvgDefault,
4511 TosaArgGen.agFFT2d,
4512 ),
4513 "types": [DType.FP32],
4514 "error_if_validators": (
4515 TosaErrorValidator.evWrongInputType,
4516 TosaErrorValidator.evWrongOutputType,
4517 TosaErrorValidator.evWrongInputList,
4518 TosaErrorValidator.evWrongOutputList,
4519 TosaErrorValidator.evWrongRank,
4520 TosaErrorValidator.evBatchMismatch,
4521 TosaErrorValidator.evKernelNotPowerOfTwo,
4522 TosaErrorValidator.evFFTInputShapeMismatch,
4523 TosaErrorValidator.evFFTOutputShapeMismatch,
4524 ),
4525 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004526 "rfft2d": {
4527 "op": Op.RFFT2D,
4528 "operands": (1, 0),
4529 "rank": (3, 3),
4530 "build_fcn": (
4531 build_rfft2d,
4532 TosaTensorGen.tgRFFT2d,
4533 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004534 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004535 ),
4536 "types": [DType.FP32],
4537 "error_if_validators": (
4538 TosaErrorValidator.evWrongInputType,
4539 TosaErrorValidator.evWrongOutputType,
4540 TosaErrorValidator.evWrongInputList,
4541 TosaErrorValidator.evWrongOutputList,
4542 TosaErrorValidator.evWrongRank,
4543 TosaErrorValidator.evBatchMismatch,
4544 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004545 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004546 ),
4547 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004548 }
4549
Kevin Cheng550ccc52021-03-03 11:21:43 -08004550
Eric Kunzee5e26762020-10-13 16:11:07 -07004551class OutputShaper:
4552 # Methods in this class compute the expected output shape and datatype
4553 # for common classes of operations
def __init__(self):
    """No per-instance state is set up; the shape helpers are static methods."""
4556
4557 # These methods return arguments that can be used for
4558 # creating a new output tensor
4559 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004560 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4561 if error_name != ErrorIf.RankMismatch:
4562 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004563 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004564
4565 shape = []
4566 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004567 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004568 shape.append(b.shape[i])
4569 else:
4570 shape.append(a.shape[i])
4571
Jerry Ge135c9552023-05-23 20:59:32 +00004572 fuzz_idx = rng.integers(0, len(a.shape))
4573 if error_name == ErrorIf.DimensionMismatch:
4574 shape[fuzz_idx] += 1
4575
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004576 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004577 all_dtypes = [
4578 DType.INT8,
4579 DType.INT16,
4580 DType.INT32,
4581 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004582 DType.FP16,
4583 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004584 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004585 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004586 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4587 outputDType = rng.choice(wrong_dtypes)
4588 else:
4589 outputDType = a.dtype
4590
4591 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004592
4593 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose operands must match exactly.

    ``a`` and ``b`` must have the same rank, dtype and per-dimension sizes
    (checked with asserts). The output has ``a``'s shape and dtype.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004604
4605 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output always has ``a.shape``. Its dtype is ``a.dtype``, unless
    ``error_name`` is ``ErrorIf.WrongOutputType``; in that case a random
    dtype other than ``a.dtype`` is chosen.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004623
4624 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    The output rank follows ``cond``. For a valid test (``error_name is None``),
    each dimension where ``cond`` is 1 becomes the largest of the three operand
    sizes at that index. Otherwise ``cond``'s size is kept.

    ERROR_IF handling:
      * ``ErrorIf.RankMismatch`` skips the equal-rank assertion.
      * ``ErrorIf.DimensionMismatch`` adds 1 to one randomly chosen dimension.
      * ``ErrorIf.WrongOutputType`` picks a random dtype other than ``a.dtype``.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and error_name is None)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally, so one RNG value is consumed on every call.
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004657
4658 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison op.

    Broadcasting: wherever ``a`` has size 1 and ``b`` has a dimension at the
    same index, ``b``'s size is used. Otherwise ``a``'s size is kept. This
    broadcast is applied regardless of ``error_name``.

    The output dtype is ``DType.BOOL``. For ``ErrorIf.WrongOutputType`` a
    random non-BOOL dtype is chosen instead. ``ErrorIf.DimensionMismatch``
    adds 1 to one randomly chosen dimension, and ``ErrorIf.RankMismatch``
    skips the equal-rank assertion.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # The length guard keeps this safe when the ranks differ (RankMismatch).
    out_shape = [
        b.shape[idx] if (dim == 1 and len(b.shape) > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally, so one RNG value is consumed on every call.
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.BOOL
    else:
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004691
4692 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a REDUCE_* op along ``axis``.

    For a valid test the reduced axis is kept with size 1. For the axis-related
    errors (``AxisSmallerZero``, ``AxisLargerRank``, ``ShapeOfAxisNotOne``) the
    input shape is kept as-is. For ``ShapeOfAxisNotOne``, if that axis already
    has size 1, it is replaced by ``rng.integers(2, 10)``.

    The output dtype is ``a.dtype``. For ``ErrorIf.WrongOutputType`` a random
    different dtype is chosen instead.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name in axis_errors:
        if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    else:
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004720
4721 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX along ``axis``.

    The reduced axis is removed from ``a``'s shape. It is left in place when
    the requested error is an invalid axis (``AxisSmallerZero`` or
    ``AxisLargerRank``). The output dtype is ``DType.INT32``.

    ERROR_IF handling:
      * ``ArgmaxOutputRankMismatch``: randomly either drop the leading
        dimension (only when more than one dimension remains) or append a
        trailing 1.
      * ``ArgmaxOutputShapeMismatch``: add ``rng.integers(1, 10)`` to every
        dimension.
      * ``WrongOutputType``: pick a random dtype other than INT32.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # One RNG draw per dimension, in order.
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - set([DType.INT32])))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004754
4755 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output feature map tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding[0:2]`` applies
    to H and ``padding[2:4]`` to W. ``strides`` and ``dilations`` are indexed
    [H, W]. Each output spatial size is
    ``(in - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1``.

    ERROR_IF handling:
      * ``ConvOutputShapeMismatch``: grow H, W or both by a random multiple
        (1-3) of the matching stride.
      * ``WrongInputType``: output dtype falls back to INT32.
      * ``WrongOutputType``: a random usable dtype is picked, excluding FP16
        and FP32 for FP16 input, otherwise excluding the expected dtype.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    batch, in_h, in_w = ifm.shape[0], ifm.shape[1], ifm.shape[2]
    out_channels, kernel_h, kernel_w = filter.shape[0], filter.shape[1], filter.shape[2]

    out_h = (
        in_h - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        in_w - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both.
        which = rng.choice(choices)
        # Grow by whole multiples of the stride so this does not also trigger
        # the non-integer output error case.
        if which in (1, 3):
            out_h += rng.choice(choices) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(choices) * strides[1]

    # With an invalid input dtype, use a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput([batch, out_h, out_w, out_channels], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004806
4807 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a CONV3D operator to the serializer.

    Layouts: the input feature map is NDHWC, the filter is ODHWI and the
    output is NDHWC.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF perturbations.
        ifm: Input feature map tensor.
        filter: Weight tensor.
        accum_dtype: Accumulator type, used as the output dtype.
        strides: Per-axis strides for D, H and W.
        padding: Six padding values, a pair each for D, H and W.
        dilations: Per-axis dilations for D, H and W.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    # Output extent for each spatial axis (D, H, W):
    #   (in - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1
    spatial = [
        (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3, 4]
        selected = rng.choice(deltas)
        # 1/2/3 perturb only D/H/W respectively; 4 perturbs all three axes.
        for axis in range(3):
            if selected in (axis + 1, 4):
                # Grow by a whole multiple of the stride so the non-integer
                # output ERROR_IF is not triggered instead.
                spatial[axis] = spatial[axis] + rng.choice(deltas) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With an invalid input type there is no expected accumulator, so INT32
    # is used as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both acceptable);
        # otherwise only the expected output type is excluded.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
4868
4869 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a DEPTHWISE_CONV2D operator.

    Layouts: the input feature map is NHWC, the filter is HWCM and the output
    is NHW(C*M).

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF perturbations.
        ifm: Input feature map tensor.
        filter: Weight tensor.
        accum_dtype: Accumulator type, used as the output dtype.
        strides: Per-axis strides for H and W.
        padding: Four padding values, a pair each for H and W.
        dilations: Per-axis dilations for H and W.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    # Output extent for H then W:
    #   (in - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1
    spatial = [
        (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        selected = rng.choice(deltas)
        # 1 perturbs H, 2 perturbs W, 3 perturbs both.
        for axis in range(2):
            if selected in (axis + 1, 3):
                # Grow by a whole multiple of the stride so the non-integer
                # output ERROR_IF is not triggered instead.
                spatial[axis] = spatial[axis] + rng.choice(deltas) * strides[axis]

    channels_out = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], *spatial, channels_out]

    # With an invalid input type there is no expected accumulator, so INT32
    # is used as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both acceptable);
        # otherwise only the expected output type is excluded.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004919
4920 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (NHWC in and out).

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        ifm: Input tensor.
        kernel: Kernel size for H and W.
        stride: Strides for H and W.
        pad: Four padding values, a pair each for H and W.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        # Invalid stride/padding: the test is invalid anyway, so use a 1x1
        # spatial output.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        deltas = [1, 2, 3]
        selected = rng.choice(deltas)
        # Grow by whole strides so the non-integer ERROR_IF is not hit.
        if selected in (1, 3):
            out_h = out_h + rng.choice(deltas) * stride[0]
        if selected in (2, 3):
            out_w = out_w + rng.choice(deltas) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004957
4958 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a FULLY_CONNECTED operator.

    Shapes: input is [N, IC], filter is [OC, IC] and the output is [N, OC].

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Unused.
        input: Input tensor.
        filter: Weight tensor.
        accum_dtype: Output dtype; validated (or deliberately invalidated for
            ERROR_IF tests) by the argument generator.
        error_name: Unused here.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004970
4971 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a MATMUL operator.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W] and the output is
    [N, H, W].

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Left-hand input tensor.
        b: Right-hand input tensor.
        accum_dtype: Accumulator type (validated in the argument generator),
            used as the output dtype.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # Excludes INT32.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # Excludes INT48.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: any integer type is wrong.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype: choose from every usable dtype except the
            # expected accumulator type, rather than leaving incorrect_types
            # unbound (which would raise UnboundLocalError below).
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005018
5019 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for a CONCAT operator.

    The output shape is the first input's shape with the ``axis`` extents of
    all inputs summed. When the ranks differ or the axis is invalid
    (ERROR_IF cases), the first input's shape is used unchanged.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        axis: Concatenation axis.
        inputs: Sequence of input tensors; the first sets the base shape/dtype.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    can_sum_axis = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum_axis:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - set([first.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005054
5055 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for a PAD operator.

    Each dimension grows by ``padding[i][0] + padding[i][1]``.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Input tensor.
        padding: Per-dimension pairs of padding amounts.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005085
5086 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a DIM operator.

    The output has an empty shape list and dtype SHAPE. For WrongOutputType,
    a random dtype from a fixed list that does not include SHAPE is used
    instead.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Input tensor (not inspected here).
        axis: Queried axis (not used for the output description).
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        non_shape_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape_dtypes)))
    else:
        out_dtype = DType.SHAPE

    return ser.addOutput([], out_dtype)
5106
5107 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for a RESHAPE operator.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Input tensor.
        shape: Requested output shape (copied, not modified).
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Enlarge every dimension so the element counts no longer match.
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005131
5132 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for a SLICE operator.

    The output shape is ``size``, except in the ERROR_IF cases handled below.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        input: Input tensor.
        start: Slice start indices (not used for the output description).
        size: Slice sizes (copied, not modified).
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([input.dtype])))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dimensions <= 2 are only ever increased (presumably to keep them
            # positive); larger ones may shrink or grow.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005165
5166 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for a TILE operator.

    Each output dimension is the input dimension times its multiple.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Input tensor.
        multiples: Per-dimension repeat counts; must match the input rank.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005194
5195 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for a TRANSPOSE operator.

    Output dimension ``i`` is input dimension ``perms[i]``. When ``perms``
    is deliberately invalid (index ERROR_IF cases), the input shape is kept.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        a: Input tensor.
        perms: Permutation; its length must match the input rank.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    assert len(perms) == len(a.shape)

    if error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        output_shape = a.shape.copy()
    else:
        output_shape = [a.shape[perms[i]] for i in range(len(a.shape))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005227
5228 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for a GATHER operator.

    Shapes (checked unless error_name is WrongRank): ``values`` is
    [N, _, C] and ``indices`` is [N, W]; the output is [N, W, C].

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        values: Source tensor.
        indices: Index tensor.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([values.dtype])))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005253
5254 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for a SCATTER operator.

    Shapes (checked unless error_name is WrongRank): ``values_in`` is
    [N, _, C], ``indices`` is [N, W] and ``input`` is [N, W, C]. The output
    has the shape of ``values_in``.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        values_in: Tensor being scattered into; sets the output shape/dtype.
        indices: Index tensor.
        input: Tensor of values to scatter.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy (as the other shape functions do) so the output tensor does not
    # share its shape list with values_in.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005282
5283 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for a TABLE operator.

    The output keeps the input shape. An INT16 input gives an INT32 output;
    any other input gives INT8.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        input: Input tensor (INT8 or INT16 unless WrongInputType).
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            if dtype != out_dtype
        ]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005302
5303 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for a RESIZE operator (NHWC).

    Output height/width follow, per axis,
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``.

    Args:
        serializer: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        input: Input tensor.
        mode: Resize mode (not used for the shape calculation).
        scale: [scale_y_n, scale_y_d, scale_x_n, scale_x_d].
        offset: [offset_y, offset_x].
        border: [border_y, border_x].
        input_dtype: Not used here.
        output_dtype: Dtype of the output tensor.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``serializer.addOutput``.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp the scale so the shape arithmetic stays well defined.
        y_n, y_d, x_n, x_d = (max(v, 1) for v in (y_n, y_d, x_n, x_d))

    out_h = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    out_w = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # ERROR_IF tweaks to scale/offset/border can yield a degenerate size,
        # so keep the output tensor itself valid.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            ceiling = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, ceiling), min(out_w, ceiling)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _nudge(size, step):
            # Move by one scale denominator so the non-integer ERROR_IF is not
            # hit; step down if stepping up would reach the maximum dimension.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size = size - step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        selected = rng.choice([1, 2, 3])
        if selected in (1, 3):
            out_h = _nudge(out_h, y_d)
        if selected in (2, 3):
            out_w = _nudge(out_w, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last entry reuses the batch size rather than the
        # channel count, presumably because input.shape[3] may not exist when
        # the input rank is wrong.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005381
5382 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type conversion.

    The output keeps the input's shape and takes ``out_dtype``; ``rng`` and
    ``error_name`` are not used.
    """
    out_shape = val.shape
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005385
5386 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a TRANSPOSE_CONV2D operator.

    Args:
        ser: Serializer that receives the new output tensor.
        rng: Random generator, only used for ERROR_IF cases.
        ifm: Input feature map tensor.
        output_shape: Output shape from the caller. NOTE(review): for
            ConvOutputShapeMismatch entries 1 and 2 of this list are modified
            in place; confirm the callers rely on (or tolerate) this.
        accum_dtype: Accumulator type, used as the output dtype.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        The output tensor created by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        selected = rng.choice(deltas)
        # 1 perturbs dim 1, 2 perturbs dim 2, 3 perturbs both.
        for dim in (1, 2):
            if selected in (dim, 3):
                output_shape[dim] = output_shape[dim] + rng.choice(deltas)

    # With an invalid input type there is no expected accumulator, so INT32
    # is used as a plausible output type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both acceptable);
        # otherwise only the expected output type is excluded.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005411
5412 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an FFT2D operator.

    Both outputs take the shape and dtype of the inputs, except in the
    ERROR_IF cases handled below.

    Args:
        serializer: Serializer that receives the new output tensors.
        rng: Random generator, only used for ERROR_IF cases.
        ifm1: First input tensor.
        ifm2: Second input tensor; same dtype (and shape, unless
            FFTInputShapeMismatch) as ``ifm1``.
        error_name: Optional ErrorIf case to provoke.

    Returns:
        A list of the two output tensors created by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    out_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    # Both outputs share the same shape list and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5442
5443 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an RFFT2D operator.

    The outputs keep all but the last input dimension; the last becomes
    ``W // 2 + 1`` (the usual real-FFT output width).

    Args:
        serializer: Serializer that receives the new output tensors.
        rng: Random generator, only used for ERROR_IF cases.
        value: Input tensor (rank 3 unless error_name is WrongRank).
        error_name: Optional ErrorIf case to provoke.

    Returns:
        A list of the two output tensors created by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = [*in_shape[:-1], in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    # Both outputs share the same shape list and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]