blob: ba10dcff309dea313c5ac8b3d656476d58cccd87 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file.
# The copyright year is taken from the date on which this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Largest tensor rank the test generator will produce.
# This currently matches the 8K level defined in the TOSA specification.
TOSA_TENSOR_MAX_RANK = 6
# Other 8K level maximums: scale, kernel size and stride.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8_192
TOSA_8K_LEVEL_MAX_STRIDE = 8_192

# Main compliance dot product statistical test range.
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1_000
42
def __init__(self, args):
    """Initialise generator state from the parsed command-line arguments.

    Sets up the random number generator (seeded with args.random_seed), the
    op lists, the quantization helper, the desc.json schema validator, the
    data generator library, and the per-float-dtype random value ranges.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Used to validate desc.json content before it is written
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library is sometimes needed for compliance set up,
    # even when the data itself is generated later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    # Work out the floating point (low, high) range for each float dtype.
    # The program arguments may use "max"/"-max" to mean that dtype's
    # largest (or most negative) value.
    self.random_float_range = {}
    for float_dtype in (DType.FP32, DType.FP16, DType.BF16):
        fp_max = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[float_dtype]
        bounds = [
            fp_max if bound == "max" else -fp_max if bound == "-max" else bound
            for bound in args.tensor_fp_value_range
        ]
        self.random_float_range[float_dtype] = tuple(sorted(bounds))
78
def createSerializer(self, opName, testPath):
    """Create a fresh TOSA serializer for a single test.

    The directory basePath/opName/testPath is created if it does not exist.
    Its path relative to basePath is stored in self.testPath.

    How constant data is handled depends on the command-line options:
    - lazy_data_gen: constants are written out as input files.
    - dump_consts: constants are embedded and also dumped.
    - otherwise: constants are embedded in the flatbuffer.
    """
    self.testPath = os.path.join(opName, testPath)
    full_path = os.path.join(self.basePath, self.testPath)
    os.makedirs(full_path, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED

    self.ser = ts.TosaSerializer(full_path, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    serializer = self.ser
    return serializer
95
def serialize(self, testName, metaData=None):
    """Write the current test's files into basePath/testPath.

    Files written:
      * ``<testName>.tosa``: the serialized flatbuffer.
      * ``<testName>_meta_data_gen.cpp``: metaData["data_gen"] as a C++ raw
        string. Written only when that key exists and lazy data generation
        is enabled.
      * ``<testName>_meta_compliance.cpp``: metaData["compliance"] as a C++
        raw string. Written only when that key exists.
      * ``desc.json``: the serializer's JSON descriptor, with metaData (if
        any) stored under "meta".

    The descriptor is schema-validated after the .tosa file is written and
    before any of the other files are written.

    Args:
        testName: base name for the .tosa and .cpp files.
        metaData: optional dict of extra meta data (None or empty means none).
    """
    path = Path(self.basePath) / self.testPath

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer and attach any extra meta data
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        # Each entry: (metaData key, file name suffix, header description,
        # C++ variable name prefix)
        cpp_outputs = []
        if "data_gen" in metaData and self.args.lazy_data_gen:
            cpp_outputs.append(
                ("data_gen", "meta_data_gen", "data generation setup", "json_tdg_config")
            )
        if "compliance" in metaData:
            cpp_outputs.append(
                ("compliance", "meta_compliance", "compliance validation", "json_tvf_config")
            )

        for key, suffix, description, var_prefix in cpp_outputs:
            # Output the meta data as CPP data.
            # NOTE: the C++ raw string R"(...)" would end early if the JSON
            # ever contained the character pair )"
            path_md = path / f"{testName}_{suffix}.cpp"
            with path_md.open("w", encoding="utf-8") as fd:
                fd.write(TOSA_AUTOGENERATED_HEADER)
                fd.write(f"// Test meta data for {description}\n\n")
                fd.write(f'const char* {var_prefix}_{path.stem} = R"(')
                json.dump(metaData[key], fd)
                fd.write(')";\n\n')

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w", encoding="utf-8") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Replace self.rng with a freshly seeded numpy Generator.

    Args:
        seed: seed for the new generator. When None, the initial
            random_seed plus one is used.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) boundaries for random values of dtype.

    Integer-like dtypes:
        high is exclusive (low <= value < high) by default. When
        high_inclusive is True, (low, high - 1) is returned instead.
    Float dtypes (FP32/FP16/BF16):
        the configured self.random_float_range entry is returned unchanged,
        whatever the value of high_inclusive.

    Raises:
        Exception: if dtype is not recognised.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # Ordered table of (matching dtypes, exclusive (low, high) range)
    integer_ranges = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # SHAPE shares the INT32 range to restrict overly large values
        ((DType.INT32, DType.SHAPE), (-(1 << 31), (1 << 31))),
        ((DType.INT48,), (-(1 << 47), (1 << 47))),
    )
    for matching_dtypes, (low, high) in integer_ranges:
        if dtype in matching_dtypes:
            # Inclusive range: low <= value <= high - 1
            return (low, high - 1) if high_inclusive else (low, high)

    raise Exception("Unknown dtype: {}".format(dtype))
178
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of the given shape for dtype.

    Args:
        shape: tensor shape.
        dtype: TOSA DType of the values.
        data_range: optional (low, high) pair. Defaults to getDTypeRange(dtype).

    Returns:
        np.bool_ for BOOL, np.int64 for INT48, np.float16 for FP16 and
        np.float32 for FP32/BF16 (BF16 passed through gtu.vect_f32_to_bf16).
        Every other dtype returns np.int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)

    samples_f32 = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(samples_f32))
    return samples_f32
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached to each placeholder. When
    lazy data generation is enabled, None is passed instead.

    Returns:
        list of the values returned by self.ser.addPlaceholder, in order.
    """
    assert len(shape_list) == len(dtype_list)

    lazy = self.args.lazy_data_gen
    return [
        self.ser.addPlaceholder(
            shape, dtype, None if lazy else self.getRandTensor(shape, dtype)
        )
        for shape, dtype in zip(shape_list, dtype_list)
    ]
217
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data from getRandTensor is attached to each constant. When lazy
    data generation is enabled, None is passed instead.

    Returns:
        list of the values returned by self.ser.addConst, in order.
    """
    assert len(shape_list) == len(dtype_list)

    lazy = self.args.lazy_data_gen
    return [
        self.ser.addConst(
            shape, dtype, None if lazy else self.getRandTensor(shape, dtype)
        )
        for shape, dtype in zip(shape_list, dtype_list)
    ]
230
def makeShape(self, rank):
    """Return a tensor shape as an np.int32 array.

    If a (truthy) target shape has been set, it is returned whatever the
    rank. Otherwise rank dimensions are drawn from
    [tensor_shape_range[0], tensor_shape_range[1]).
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
def setTargetShape(self, shape):
    """Make makeShape return shape; pass None to go back to random shapes."""
    setattr(self, "targetted_shape", shape)
244
def randInt(self, low=0, high=256):
    """Return a single random np.int32 drawn from [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
247
def getRandNumberDType(self, dtype):
    """Return a single random scalar value for dtype.

    Values are drawn from getDTypeRange(dtype), with high exclusive.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.FP32, DType.BF16):
        value_f32 = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(value_f32)
        return value_f32
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    draw = self.rng.integers(low, high, size=1)
    # INT48 needs 64-bit storage; every other dtype is returned as np.int32
    if dtype == DType.INT48:
        return np.int64(draw)[0]
    return np.int32(draw)[0]
265
def shapeStr(self, shape):
    """Return the dimensions of shape joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
def typeStr(self, dtype):
    """Return the short string name of dtype, as used in test names.

    A list or tuple of at least two dtypes is shown as the first two names
    joined by "x". Any further entry (the accumulator type) is still
    converted, so it is still checked, but it is left out of the string.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    Raises:
        Exception: if dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
def constrictBatchSize(self, shape):
    """Clamp shape[0] to args.max_batch_size and return shape.

    Nothing is changed when max_batch_size is unset (falsy) or when explicit
    target shapes were given. shape is modified in place.
    """
    batch_limit = self.args.max_batch_size
    if not batch_limit or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], batch_limit)
    return shape
301
def makeDimension(self):
    """Return one random dimension from [tensor_shape_range[0], tensor_shape_range[1])."""
    shape_range = self.args.tensor_shape_range
    lowest, upper = shape_range[0], shape_range[1]
    return self.randInt(low=lowest, high=upper)
306
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for an op's expected output tensor.

    Returns None (no compliance) in three cases:
    - error tests;
    - output dtypes not supported by compliance;
    - MATMUL/CONV2D/FULLY_CONNECTED when inputType is not supported by
      compliance.

    Otherwise returns a dict with:
    - "mode": a gtu.ComplianceMode name;
    - "data_type": the JSON name of the output dtype;
    - for DOT_PRODUCT mode, "dot_product_info";
    - for ULP mode, "ulp_info".
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    unsupported_input_ops = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in unsupported_input_ops
    ):
        return None

    # "mode" is filled in below; it is created first to keep the key order
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    data_gen_type = argsDict["dg_type"]
    if data_gen_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            # "ksb" takes precedence over "ks" when present
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif data_gen_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
351
# Build Op functions:
#   create the output tensor (calling OutputShaper as needed),
#   do final tweaks to attributes (if necessary for errorIf),
#   add the Op into the graph,
#   return the resulting tensor information or a BuildInfo.

class BuildInfo:
    """Result tensor of an op build plus its compliance meta data dict."""

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700364
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input op and add it to the graph.

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed on for compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.
        qinfo: zero points [input, output]; used to build the NEGATE attribute.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, input_tensor, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input_tensor.name], [result_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=input_tensor.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    # Only NEGATE carries an attribute (its zero points)
    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700417
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a two-input op whose inputs may be broadcast, and add it to the graph.

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments, passed on for compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.
        qinfo: accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700459
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a two-input op without broadcasting and add it to the graph.

    validator_fcns and error_name are accepted but not used; no ERROR_IF
    validation happens here.

    Returns:
        The result tensor (not a BuildInfo).
    """
    output_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [output_tensor.name])
    return output_tensor
464
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Build an arithmetic right shift op (op["op"]) of a by b.

    Args:
        round: value for ArithmeticRightShiftAttribute. The name shadows the
            builtin but is kept because it is part of the signature.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    output_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [output_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], input_list, output_list, shift_attr)
    return output_tensor
502
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a multiply op of two inputs and add it to the graph.

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments; "shift" is used for the MulAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.
        qinfo: accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Replace the output type with a random pick for this error test
        result_tensor.setDtype(
            self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], input_list, output_list, mul_attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700558
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Build op["op"] on input a with a TableAttribute holding table.

    Args:
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    output_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)
    self.ser.addOperator(op["op"], input_list, output_list, table_attr)
    return output_tensor
592
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Build a select op choosing between a and b by cond, and add it to the graph.

    Args:
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    output_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [output_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return output_tensor
629
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input comparison op and add it to the graph.

    Args:
        op: operator description dict (uses "op" and "operands").
        inputs: list of exactly two input tensors.
        args_dict: test arguments, passed on for compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None.
        qinfo: accepted but not used here.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passed_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not passed_checks:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700677
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize an argmax operator test along a single axis.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance metadata.
        validator_fcns: ERROR_IF validator functions (defaults to None, like
            the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder; accepted for a uniform signature.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "stride", "pad", "kernel" and optionally
            "acc_type" are read, and the dict is forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to PoolAttribute as
            qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 1
    # Named ifm (not `input`) to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype, so acc_type may be absent from args_dict.
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100795
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" (4 values) and
            "dilation" are read, and the dict is forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute as
            qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700877
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" (6 values) and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute as
            qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        The result tensor (not a BuildInfo), or None when the ERROR_IF
        validation rejects the case.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
954
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D transpose convolution operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        ifm, filter, bias: The three input tensors.
        accum_dtype: Accumulator dtype passed to the output shaper.
        stride: Stride values.
        out_pad: Output padding (4 values).
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to TransposeConvAttribute
            as qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        The result tensor (not a BuildInfo), or None when the ERROR_IF
        validation rejects the case.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1022
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise 2D convolution operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Test arguments; "acc_type", "stride", "pad" and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute as
            qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        The result tensor (not a BuildInfo), or None when the ERROR_IF
        validation rejects the case.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1098
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully connected operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly three tensors: input feature map, filter and bias.
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to FullyConnectedAttribute
            as qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001154
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a matrix multiplication operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly two input tensors (a, b).
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to MatMulAttribute as
            qinfo[0], qinfo[1]; None is treated as [0, 0].

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # instead of failing with a TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001204
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a reduction operator test along a single axis.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance metadata.
        validator_fcns: ERROR_IF validator functions (defaults to None, like
            the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder; accepted for a uniform signature.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata (None for REDUCE_PRODUCT), or None when the ERROR_IF
        validation rejects the case.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if op["op"] == Op.REDUCE_PRODUCT:
        # TODO: Add compliance support!
        compliance = None
    else:
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001253
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a clamp operator test with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes min_val. For ErrorIf.MaxSmallerMin the values are redrawn until
    they differ and then swapped, so that max_val < min_val.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder; accepted for a uniform signature.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 1
    a = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def _draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = _draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # The bounds must differ, otherwise max < min cannot be produced.
        while bounds[0] == bounds[1]:
            bounds = _draw_pair()
        max_val, min_val = min(bounds), max(bounds)

    prim_count, const_count = op["operands"]
    # Invalidate the input/output name lists when building an ERROR_IF test.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=prim_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Integer bounds occupy the first pair of ClampAttribute values.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Floating-point bounds occupy the second pair of values.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(op, a.dtype, args_dict, out_tensor, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001319
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a leaky ReLU operator on input tensor `a`.

    A random FP32 value is passed to LeakyReluAttribute (presumably the
    negative slope; confirm against the serializer). No ERROR_IF validation
    is done here, and validator_fcns is unused.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    attr = ts.TosaSerializerAttribute()

    attr_value = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [out.name], attr)
    return out
1328
# NOTE: PRELU needs an additional type/input that this builder does not yet
# provide; only the single input `a` is serialized.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PReLU operator on input tensor `a` (no ERROR_IF checks).

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names, out_names = [a.name], [out.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1335
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize an attribute-less unary activation operator test.

    Args:
        op: Operator description dict; the "op" and "operands" keys are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder; accepted for a uniform signature.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the case.
    """
    assert len(inputs) == 1
    (src,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    prim_count, const_count = op["operands"]
    # Invalidate the input/output name lists when building an ERROR_IF test.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=prim_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tensor, error_name
        ),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001376
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001377 def build_concat(
1378 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1379 ):
1380 axis = args_dict["axis"]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001381 if error_name != ErrorIf.WrongInputType:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001382 assert type(axis) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001383
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001384 result_tensor = OutputShaper.concatOp(
1385 self.ser, self.rng, axis, inputs, error_name=error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001386 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001387
Matthew Haddon818ab902021-07-27 09:12:49 +01001388 input_tensor_names = []
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001389 for tensor in inputs:
Matthew Haddon818ab902021-07-27 09:12:49 +01001390 input_tensor_names.append(tensor.name)
1391
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001392 # Invalidate Input/Output list for error if checks.
1393 input_list = input_tensor_names
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001394 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001395 pCount, cCount = op["operands"]
1396 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001397 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1398 self, error_name, input_list, output_list
1399 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001400
Les Bell729b0352021-11-24 10:28:21 +00001401 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001402 self.ser,
1403 validator_fcns,
1404 error_name,
1405 op=op,
1406 axis=axis,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001407 input_shape=inputs[0].shape,
1408 output_shape=result_tensor.shape,
1409 input_dtype=inputs[0].dtype,
1410 output_dtype=result_tensor.dtype,
1411 inputs=inputs,
1412 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001413 input_list=input_list,
1414 output_list=output_list,
1415 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001416 ):
1417 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001418
1419 attr = ts.TosaSerializerAttribute()
1420 attr.AxisAttribute(axis)
1421
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001422 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001423 return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001424
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001425 def build_pad(
1426 self,
1427 op,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001428 inputs,
1429 args_dict,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001430 validator_fcns=None,
1431 error_name=None,
1432 qinfo=None,
1433 ):
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001434 assert len(inputs) == 1
1435 a = inputs[0]
1436 padding = args_dict["pad"]
1437 pad_const_int = args_dict["pad_const_int"]
1438 pad_const_float = args_dict["pad_const_fp"]
1439
1440 result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001441
Kevin Chengfe392ce2021-10-18 21:51:55 +00001442 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001443 attr.PadAttribute(
1444 self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
1445 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001446
Matthew Haddone807aae2021-10-11 18:12:58 +01001447 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001448 input_list = [a.name]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001449 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001450 pCount, cCount = op["operands"]
1451 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001452 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1453 self, error_name, input_list, output_list
1454 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001455
Les Bell729b0352021-11-24 10:28:21 +00001456 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001457 self.ser,
1458 validator_fcns,
1459 error_name,
1460 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001461 input_shape=a.shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001462 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001463 input_dtype=a.dtype,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001464 output_dtype=result_tensor.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001465 pad=padding,
1466 qinfo=qinfo,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001467 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001468 input_list=input_list,
1469 output_list=output_list,
1470 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001471 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001472 ):
1473 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001474
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001475 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001476
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001477 compliance = self.tensorComplianceMetaData(
1478 op, a.dtype, args_dict, result_tensor, error_name
1479 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001480
1481 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001482
Won Jeona21b2e82023-08-10 10:33:01 +00001483 def build_dim(
1484 self,
1485 op,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001486 inputs,
1487 args_dict,
Won Jeona21b2e82023-08-10 10:33:01 +00001488 validator_fcns=None,
1489 error_name=None,
1490 qinfo=None,
1491 ):
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001492 assert len(inputs) == 1
1493 a = inputs[0]
1494 axis = args_dict["axis"]
1495 result_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
Won Jeona21b2e82023-08-10 10:33:01 +00001496
1497 # Invalidate Input/Output list for error if checks.
1498 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001499 output_list = [result_tensor.name]
Won Jeona21b2e82023-08-10 10:33:01 +00001500 pCount, cCount = op["operands"]
1501 num_operands = pCount + cCount
1502 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1503 self, error_name, input_list, output_list
1504 )
1505
1506 if not TosaErrorValidator.evValidateErrorIfs(
1507 self.ser,
1508 validator_fcns,
1509 error_name,
1510 op=op,
1511 axis=axis,
1512 input_shape=a.shape,
1513 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001514 output_shape=result_tensor.shape,
1515 output_dtype=result_tensor.dtype,
1516 result_tensors=[result_tensor],
Won Jeona21b2e82023-08-10 10:33:01 +00001517 input_list=input_list,
1518 output_list=output_list,
1519 num_operands=num_operands,
1520 ):
1521 return None
1522
1523 attr = ts.TosaSerializerAttribute()
1524 attr.AxisAttribute(axis)
1525
1526 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001527 return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001528
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001529 def build_reshape(
1530 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1531 ):
1532 assert len(inputs) == 1
1533 a = inputs[0]
1534 new_shape = args_dict["new_shape"]
1535 result_tensor = OutputShaper.reshapeOp(
1536 self.ser, self.rng, a, new_shape, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001537 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001538
1539 # Invalidate Input/Output list for error if checks.
1540 input_list = [a.name]
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001541 output_list = [result_tensor.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001542 pCount, cCount = op["operands"]
1543 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001544 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1545 self, error_name, input_list, output_list
1546 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001547
Les Bell729b0352021-11-24 10:28:21 +00001548 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001549 self.ser,
1550 validator_fcns,
1551 error_name,
1552 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001553 input_shape=a.shape,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001554 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001555 input_dtype=a.dtype,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001556 output_dtype=result_tensor.dtype,
1557 result_tensors=[result_tensor],
Matthew Haddone807aae2021-10-11 18:12:58 +01001558 input_list=input_list,
1559 output_list=output_list,
1560 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001561 ):
1562 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001563
1564 attr = ts.TosaSerializerAttribute()
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001565 attr.ReshapeAttribute(new_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001566
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001567 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00001568
1569 compliance = self.tensorComplianceMetaData(
1570 op, a.dtype, args_dict, result_tensor, error_name
1571 )
1572
1573 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001574
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001575 def build_reverse(
1576 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1577 ):
1578 assert len(inputs) == 1
1579 a = inputs[0]
1580 axis = args_dict["axis"]
1581 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001582
1583 # Invalidate Input/Output list for error if checks.
1584 input_list = [a.name]
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001585 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001586 pCount, cCount = op["operands"]
1587 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001588 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1589 self, error_name, input_list, output_list
1590 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001591
Les Bell729b0352021-11-24 10:28:21 +00001592 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001593 self.ser,
1594 validator_fcns,
1595 error_name,
1596 op=op,
1597 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001598 input_shape=a.shape,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001599 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001600 input_dtype=a.dtype,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001601 output_dtype=result_tensor.dtype,
1602 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001603 input_list=input_list,
1604 output_list=output_list,
1605 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001606 ):
1607 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001608
1609 attr = ts.TosaSerializerAttribute()
1610 attr.AxisAttribute(axis)
1611
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001612 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001613 return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001614
Matthew Haddone807aae2021-10-11 18:12:58 +01001615 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
1616 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001617
Kevin Chengfe392ce2021-10-18 21:51:55 +00001618 attr = ts.TosaSerializerAttribute()
1619 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07001620
Matthew Haddone807aae2021-10-11 18:12:58 +01001621 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001622 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001623 output_list = [result_tens.name]
1624 pCount, cCount = op["operands"]
1625 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001626 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1627 self, error_name, input_list, output_list
1628 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001629
Les Bell729b0352021-11-24 10:28:21 +00001630 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001631 self.ser,
1632 validator_fcns,
1633 error_name,
1634 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001635 input_shape=a.shape,
1636 output_shape=result_tens.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01001637 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001638 input_dtype=a.dtype,
1639 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001640 result_tensors=[result_tens],
Matthew Haddone807aae2021-10-11 18:12:58 +01001641 input_list=input_list,
1642 output_list=output_list,
1643 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001644 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001645 ):
1646 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001647
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001648 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001649 return result_tens
1650
Matthew Haddone807aae2021-10-11 18:12:58 +01001651 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001652 result_tens = OutputShaper.sliceOp(
1653 self.ser, self.rng, a, start, size, error_name
1654 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001655
1656 # Invalidate Input/Output list for error if checks.
1657 input_list = [a.name]
1658 output_list = [result_tens.name]
1659 pCount, cCount = op["operands"]
1660 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001661 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1662 self, error_name, input_list, output_list
1663 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001664
Les Bell729b0352021-11-24 10:28:21 +00001665 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001666 self.ser,
1667 validator_fcns,
1668 error_name,
1669 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001670 input_shape=a.shape,
1671 output_shape=result_tens.shape,
1672 input_dtype=a.dtype,
1673 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001674 start=start,
1675 size=size,
Luke Hutton261b7b62023-01-10 14:50:31 +00001676 result_tensors=[result_tens],
Matthew Haddone807aae2021-10-11 18:12:58 +01001677 input_list=input_list,
1678 output_list=output_list,
1679 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001680 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001681 ):
1682 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001683
1684 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01001685 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07001686
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001687 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001688 return result_tens
1689
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001690 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
1691 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
1692
1693 # Invalidate Input/Output list for error if checks.
1694 input_list = [a.name]
1695 output_list = [result_tens.name]
1696 pCount, cCount = op["operands"]
1697 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001698 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1699 self, error_name, input_list, output_list
1700 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001701
Les Bell729b0352021-11-24 10:28:21 +00001702 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001703 self.ser,
1704 validator_fcns,
1705 error_name,
1706 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001707 input_shape=a.shape,
1708 output_shape=result_tens.shape,
1709 input_dtype=a.dtype,
1710 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001711 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001712 input_list=input_list,
1713 output_list=output_list,
1714 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001715 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001716 ):
1717 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001718
1719 attr = ts.TosaSerializerAttribute()
1720 attr.TileAttribute(multiples)
1721
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001722 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001723 return result_tens
1724
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001725 def build_gather(self, op, values, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001726
1727 # Create a new indicies tensor
1728 # here with data that doesn't exceed the dimensions of the values tensor
1729
Kevin Cheng550ccc52021-03-03 11:21:43 -08001730 K = values.shape[1] # K
1731 W = self.randInt(
1732 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1733 ) # W
1734 indicies_arr = np.int32(
1735 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1736 ) # (N, W)
1737 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001738
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001739 result_tens = OutputShaper.gatherOp(
1740 self.ser, self.rng, values, indicies, error_name
1741 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001742
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001743 # Invalidate Input/Output list for error if checks.
1744 input_list = [values.name, indicies.name]
1745 output_list = [result_tens.name]
1746 pCount, cCount = op["operands"]
1747 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001748 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1749 self, error_name, input_list, output_list
1750 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001751
Les Bell729b0352021-11-24 10:28:21 +00001752 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001753 self.ser,
1754 validator_fcns,
1755 error_name,
1756 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001757 input_shape=values.shape,
1758 output_shape=result_tens.shape,
1759 input_dtype=values.dtype,
1760 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001761 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001762 input_list=input_list,
1763 output_list=output_list,
1764 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001765 ):
1766 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001767
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001768 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001769
1770 return result_tens
1771
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001772 def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
Kevin Cheng77d0f762020-11-24 10:26:32 -08001773
1774 # Create a new indicies tensor
1775 # here with data that doesn't exceed the dimensions of the values_in tensor
1776
Kevin Cheng550ccc52021-03-03 11:21:43 -08001777 K = values_in.shape[1] # K
1778 W = input.shape[1] # W
1779 indicies_arr = np.int32(
1780 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1781 ) # (N, W)
1782 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001783
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001784 result_tens = OutputShaper.scatterOp(
1785 self.ser, self.rng, values_in, indicies, input, error_name
1786 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001787
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001788 # Invalidate Input/Output list for error if checks.
1789 input_list = [values_in.name, indicies.name, input.name]
1790 output_list = [result_tens.name]
1791 pCount, cCount = op["operands"]
1792 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001793 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1794 self, error_name, input_list, output_list
1795 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001796
Les Bell729b0352021-11-24 10:28:21 +00001797 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001798 self.ser,
1799 validator_fcns,
1800 error_name,
1801 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001802 input_shape=values_in.shape,
1803 output_shape=result_tens.shape,
1804 input_dtype=values_in.dtype,
1805 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001806 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001807 input_list=input_list,
1808 output_list=output_list,
1809 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001810 ):
1811 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08001812
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001813 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001814
Kevin Cheng77d0f762020-11-24 10:26:32 -08001815 return result_tens
1816
Kevin Cheng550ccc52021-03-03 11:21:43 -08001817 def build_resize(
1818 self,
1819 op,
1820 input,
1821 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001822 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001823 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001824 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001825 input_dtype,
1826 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001827 validator_fcns,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001828 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001829 ):
1830 result_tens = OutputShaper.resizeOp(
1831 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001832 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001833 input,
1834 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001835 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001836 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001837 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001838 input_dtype,
1839 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001840 error_name,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001841 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001842
Matthew Haddon848efb42021-09-09 12:30:53 +01001843 # Invalidate Input/Output list for error if checks.
1844 input_list = [input.name]
1845 output_list = [result_tens.name]
1846 pCount, cCount = op["operands"]
1847 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001848 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1849 self, error_name, input_list, output_list
1850 )
Matthew Haddone86fd342021-09-07 16:12:21 +01001851
Les Bell729b0352021-11-24 10:28:21 +00001852 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01001853 self.ser,
1854 validator_fcns,
1855 error_name,
1856 op=op,
1857 mode=mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001858 scale=scale,
Matthew Haddon848efb42021-09-09 12:30:53 +01001859 input_dtype=input_dtype,
1860 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001861 input_shape=input.shape,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001862 output_shape=result_tens.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01001863 offset=offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001864 border=border,
Matthew Haddon848efb42021-09-09 12:30:53 +01001865 input_list=input_list,
1866 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00001867 result_tensors=[result_tens],
Matthew Haddon848efb42021-09-09 12:30:53 +01001868 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001869 ):
1870 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01001871
Eric Kunzee5e26762020-10-13 16:11:07 -07001872 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08001873
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001874 attr.ResizeAttribute(scale, offset, border, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -07001875
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001876 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001877 return result_tens
1878
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001879 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
1880 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
1881 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001882 self.ser.addOperator(
1883 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1884 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001885 return result_tens
1886
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001887 def build_const(self, op, val, validator_fcns=None, error_name=None):
Kevin Cheng17e92022021-10-01 14:33:33 -07001888 self.ser.addOutputTensor(val)
1889 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001890
1891 # Type Conversion
Jeremy Johnson708da822023-11-15 16:25:45 +00001892 def build_cast(
1893 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
1894 ):
1895 assert len(inputs) == 1
1896 val = inputs[0]
1897 out_dtype = args_dict["out_type"]
1898
1899 result_tensor = OutputShaper.typeConversionOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001900 self.ser, self.rng, val, out_dtype, error_name
1901 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001902
1903 # Invalidate Input/Output list for error if checks.
1904 input_list = [val.name]
Jeremy Johnson708da822023-11-15 16:25:45 +00001905 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001906 pCount, cCount = op["operands"]
1907 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001908 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1909 self, error_name, input_list, output_list
1910 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001911
Les Bell729b0352021-11-24 10:28:21 +00001912 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001913 self.ser,
1914 validator_fcns,
1915 error_name,
1916 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001917 input_shape=val.shape,
Jeremy Johnson708da822023-11-15 16:25:45 +00001918 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001919 input_dtype=val.dtype,
Jeremy Johnson708da822023-11-15 16:25:45 +00001920 output_dtype=result_tensor.dtype,
1921 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001922 input_list=input_list,
1923 output_list=output_list,
1924 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001925 ):
1926 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001927
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001928 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson708da822023-11-15 16:25:45 +00001929
1930 compliance = self.tensorComplianceMetaData(
1931 op, val.dtype, args_dict, result_tensor, error_name
1932 )
1933
1934 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001935
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001936 def build_rescale(
1937 self,
1938 op,
1939 val,
1940 out_dtype,
1941 scale32,
1942 double_round,
1943 per_channel,
1944 validator_fcns,
1945 error_name,
1946 ):
1947 result_tens = OutputShaper.typeConversionOp(
1948 self.ser, self.rng, val, out_dtype, error_name
1949 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001950
1951 if per_channel:
1952 nc = val.shape[-1]
1953 else:
1954 nc = 1
1955
1956 in_type_width = self.typeWidth(val.dtype)
1957 out_type_width = self.typeWidth(out_dtype)
1958
Kevin Cheng3a478572021-01-22 17:21:02 -08001959 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001960 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001961 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001962 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001963 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001964 in_type_width += 1
1965 elif error_name in [
1966 ErrorIf.InputZeroPointNotZero,
1967 ErrorIf.U16InputZeroPointNotValid,
1968 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001969 input_zp = self.randInt(-128, 128)
1970 if input_zp == 0:
1971 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001972 in_type_width += 1
1973 elif val.dtype == DType.UINT16:
1974 # Must come after ErrorIf.U16InputZeroPointNotValid check
1975 input_zp = self.rng.choice([0, 32768])
1976 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001977 else:
1978 input_zp = 0
1979
Kevin Cheng3a478572021-01-22 17:21:02 -08001980 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001981 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001982 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001983 elif out_dtype == DType.UINT8:
1984 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001985 out_type_width += 1
1986 elif error_name in [
1987 ErrorIf.OutputZeroPointNotZero,
1988 ErrorIf.U16OutputZeroPointNotValid,
1989 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001990 output_zp = self.randInt(-128, 128)
1991 if output_zp == 0:
1992 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001993 out_type_width += 1
1994 elif out_dtype == DType.UINT16:
1995 # Must come after ErrorIf.U16OutputZeroPointNotValid check
1996 output_zp = self.rng.choice([0, 32768])
1997 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001998 else:
1999 output_zp = 0
2000
2001 # Calculate scale based on:
2002 # scale = a *(2^output_width)/(2^input_width))
2003
2004 a = np.float32(self.rng.random(size=[nc]))
2005 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2006
2007 if scale32:
2008 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002009 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002010 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2011 else:
2012 # Cap the scaling at 2^15 - 1 for scale16
2013 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2014
Kevin Cheng550ccc52021-03-03 11:21:43 -08002015 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002016
2017 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2018 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002019 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2020 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002021
2022 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002023 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2024 scale_arr[i], scale32
2025 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002026 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2027 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002028
Kevin Cheng550ccc52021-03-03 11:21:43 -08002029 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002030 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002031 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002032 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002033 assert val.placeholderFilename
2034 values = np.load(
2035 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2036 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002037 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2038 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2039 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2040 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002041 if not np.all(np.array_equal(values, val_adj)):
2042 # Values changed so overwrite file with new values
2043 np.save(
2044 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2045 val_adj,
2046 False,
2047 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002048
Matthew Haddonc2025212021-10-08 21:21:05 +01002049 # Invalidate Input/Output list for error if checks.
2050 input_list = [val.name]
2051 output_list = [result_tens.name]
2052 pCount, cCount = op["operands"]
2053 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002054 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2055 self, error_name, input_list, output_list
2056 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002057
2058 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002059 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002060 self.ser,
2061 validator_fcns,
2062 error_name,
2063 op=op,
2064 input_dtype=val.dtype,
2065 output_dtype=out_dtype,
2066 input_shape=val.shape,
2067 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002068 scale32=scale32,
2069 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002070 input_list=input_list,
2071 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002072 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002073 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002074 ):
2075 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002076
Eric Kunzee5e26762020-10-13 16:11:07 -07002077 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002078 attr.RescaleAttribute(
2079 input_zp,
2080 output_zp,
2081 multiplier_arr,
2082 shift_arr,
2083 scale32,
2084 double_round,
2085 per_channel,
2086 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002087
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002088 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002089 return result_tens
2090
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor consumed by a COND_IF operator.

    For a positive test this is a rank-0 BOOL constant holding ``cond``.
    ERROR_IF variants:
      * ErrorIf.CondIfCondNotMatchingBool - the dtype is taken from
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL) instead of BOOL.
      * ErrorIf.CondIfCondShapeNotSizeOne - the shape is randomly [2] or
        [1, 2] instead of rank 0.

    Returns:
        The constant tensor added to the serializer.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2107
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Each
    block contains a single random INT32 constant of that shape as its output.
    For ErrorIf.CondIfOutputList{Then,Else}GraphMismatch the corresponding
    block instead outputs a constant with a deliberately different shape.

    Args:
        op: operator description dict (``op["op"]`` is the opcode).
        then_tens: tensor whose shape defines the output shape.
        else_tens: unused.
        cond: value placed in the condition constant.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    # Create an incorrect output shape for error_if tests
    if then_mismatch or else_mismatch:
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor takes the (correct) output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds just one constant, which is also the block output
    block_specs = (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    )
    for block_name, use_bad_shape, good_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if use_bad_shape:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, good_arr)
        self.ser.addOutputTensor(block_out)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if passed else None
2176
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    For FP32/BF16/FP16/INT32 inputs the THEN block uses ADD and the ELSE block
    SUB; for INT8/INT16 inputs THEN uses LOGICAL_RIGHT_SHIFT and ELSE uses
    LOGICAL_LEFT_SHIFT. For the CondIf{Input,Output}List{Then,Else}GraphMismatch
    ERROR_IF tests the selected block gets an input (or output) whose shape
    differs from ``a``.

    Args:
        op: operator description dict (``op["op"]`` is the opcode).
        a, b: input tensors to the COND_IF and to each block.
        cond: value placed in the condition constant.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: naming it ``op`` would overwrite the
    # operator dict parameter, which is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # Block input shape differs from the COND_IF input
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # Block output shape differs from the COND_IF output
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2259
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` into an accumulator.

    The loop carries (iteration counter, a, acc). The COND block computes
    ``iter > 0`` (GREATER against a zero constant). The BODY block computes
    ``acc + a`` and ``iter - 1``. For the ERROR_IF tests handled here, loop
    outputs or COND/BODY block inputs/outputs get deliberately mismatched
    shapes, or the COND output gets a non-BOOL type or a non-rank-0 shape.

    Args:
        op: operator description dict (``op["op"]`` is the opcode).
        a: tensor added to the accumulator on each iteration.
        iter_val: initial value of the INT32 iteration counter.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation fails.
    """
    # Iteration counter (named iter_tens to avoid shadowing builtin iter())
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised with zeros
    # NOTE(review): data is np.int32 while the declared dtype is a.dtype
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # Rank-0 counter: give it a dimension so the shape differs
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2380
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Add an FFT2D operator taking the two input tensors ``val1`` and ``val2``.

    Result tensors come from OutputShaper.fft2dOp. For ERROR_IF tests, the
    input/output name lists may be altered by
    TosaErrorIfArgGen.eiInvalidateInputOutputList before validation.

    Returns:
        The list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tens.name for tens in results],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2431
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Add an RFFT2D operator taking the single input tensor ``val``.

    Result tensors come from OutputShaper.rfft2dOp. For ERROR_IF tests, the
    input/output name lists may be altered by
    TosaErrorIfArgGen.eiInvalidateInputOutputList before validation.

    Returns:
        The list of result tensors, or None if ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val.name],
        [tens.name for tens in results],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2477
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator description dict; reads ``op["rank"]`` (min, max) and
            ``op["types"]``.
        shapeFilter: requested shapes; a falsy value means ``[None]``
            (no shape restriction).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An op dtype that is a list
            matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; for negative tests its ``param_reqs``
            override the rank/dtype/shape filters.

    Returns:
        A dict with keys "shapeFilter", "rankFilter" and "dtypeFilter", or
        None for a negative test without a validator or an unknown testType.
    """
    # Keep default test sizes reasonably small: ranks 1-4 inclusive.
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Combine what was requested with what the operator allows
    rank_lo, rank_hi = op["rank"]
    op_ranks = range(rank_lo, rank_hi + 1)
    if rankFilter is not None:
        ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is None:
        # Use whichever of the default and operator rank ranges is smaller
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dt
            for dt in op_dtypes
            if dt in dtypeFilter or (isinstance(dt, list) and dt[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    return {
        # Without a required shape, keep only two shapes to limit test count
        "shapeFilter": shapeFilter[:2] if reqs["shape"] is None else reqs["shape"],
        "rankFilter": ranks if reqs["rank"] is None else reqs["rank"],
        "dtypeFilter": dtypes if reqs["dtype"] is None else reqs["dtype"],
    }
2558
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for a single operator.

    The random generator is re-seeded from ``self.random_seed`` on every call.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to restrict generation to; None or an empty list
            means no shape restriction (same as ``[None]``).
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests
            (one pass per entry of ``op["error_if_validators"]``).

    Returns:
        A list of tuples ``(opName, testStr, dtype, error_name, shapeList, args)``.
        For positive tests, entries flagged by any ``op["invalid_test_validators"]``
        function are removed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator; keep any tests already
            # produced for earlier validators.
            continue
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2665
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002666 def serializeTest(
2667 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2668 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002669 try:
2670 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002671 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002672 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002673
Jeremy Johnson0c716862023-04-13 17:18:19 +01002674 if self.args.verbose:
2675 print(f"Creating {testStr}")
2676
Eric Kunzee5e26762020-10-13 16:11:07 -07002677 # Create a serializer
2678 self.createSerializer(opName, testStr)
2679
Jeremy Johnson1271c442023-09-05 11:39:26 +01002680 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002681 if "error_if_validators" in op:
2682 error_if_validators = op["error_if_validators"]
2683 else:
2684 error_if_validators = None
2685
Kevin Cheng550ccc52021-03-03 11:21:43 -08002686 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002687 num_operands = pCount + cCount
2688
2689 if isinstance(dtype_or_dtypeList, list):
2690 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002691 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002692 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002693 else:
2694 dtypeList = [dtype_or_dtypeList] * (num_operands)
2695
Kevin Cheng93a16282021-08-31 16:14:03 -07002696 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002697 assert (
2698 len(shapeList) == num_operands
2699 ), "shapeList length {} must match number of operands {}".format(
2700 len(shapeList), num_operands
2701 )
2702 assert (
2703 len(dtypeList) == num_operands
2704 ), "dtypeList length {} must match number of operands {}".format(
2705 len(dtypeList), num_operands
2706 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002707
2708 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002709 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002710 except KeyError:
2711 qgen = None
2712
2713 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08002714
Matthew Haddon1c00b712021-10-01 15:51:03 +01002715 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002716 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002717 else:
2718 qinfo = None
2719
Jeremy Johnson1271c442023-09-05 11:39:26 +01002720 # Extra meta data for the desc.json
2721 tensMeta = {}
2722
2723 # Check we are using the new testArgs interface with an argsDict dictionary
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002724 if isinstance(testArgs, dict):
2725 # New interface with args info in dictionary
2726 argsDict = testArgs
Jeremy Johnson1271c442023-09-05 11:39:26 +01002727 assert "dg_type" in argsDict
2728 tvgInfo = tvgen_fcn(
2729 self, opName, dtypeList, shapeList, argsDict, error_name
2730 )
2731 if tvgInfo.dataGenDict:
2732 tensMeta["data_gen"] = tvgInfo.dataGenDict
2733 tens = tvgInfo.tensorList
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002734
2735 result = build_fcn(
2736 self,
2737 op,
2738 tens,
2739 argsDict,
2740 validator_fcns=error_if_validators,
2741 error_name=error_name,
2742 qinfo=qinfo,
2743 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01002744 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002745 # Old interface with args info in a list
Jeremy Johnson1271c442023-09-05 11:39:26 +01002746 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002747
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002748 try:
2749 if error_if_validators is None:
2750 if qinfo is not None:
2751 result = build_fcn(self, op, *tens, *testArgs, qinfo)
2752 else:
2753 result = build_fcn(self, op, *tens, *testArgs)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002754 else:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002755 if qinfo is not None:
2756 result = build_fcn(
2757 self,
2758 op,
2759 *tens,
2760 *testArgs,
2761 validator_fcns=error_if_validators,
2762 error_name=error_name,
2763 qinfo=qinfo,
2764 )
2765 else:
2766 result = build_fcn(
2767 self,
2768 op,
2769 *tens,
2770 *testArgs,
2771 validator_fcns=error_if_validators,
2772 error_name=error_name,
2773 )
2774 except TypeError as e:
2775 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
2776 raise e
Matthew Haddon1c00b712021-10-01 15:51:03 +01002777
Jeremy Johnson1271c442023-09-05 11:39:26 +01002778 if result:
Les Bell729b0352021-11-24 10:28:21 +00002779 # The test is valid, serialize it
Jeremy Johnson1271c442023-09-05 11:39:26 +01002780 if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
2781 # Add the compliance meta data
2782 # NOTE: This currently expects only one result output
2783 tensMeta["compliance"] = {
2784 "version": "0.1",
2785 "tensors": {result.resultTensor.name: result.complianceDict},
2786 }
2787 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00002788 else:
2789 # The test is not valid
2790 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002791
Eric Kunzee5e26762020-10-13 16:11:07 -07002792 def createDynamicOpLists(self):
2793
Jeremy Johnson00423432022-09-12 17:27:37 +01002794 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2795 # Already created these lists (can occur when class is initialized more than once)
2796 return
2797
Eric Kunzee5e26762020-10-13 16:11:07 -07002798 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01002799 if not self.args.level8k:
2800 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
2801 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2802 else:
2803 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
2804 KERNELS_2D = [[1, bigK], [bigK, 2]]
2805 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002806
Kevin Cheng1533b852021-09-01 12:51:58 -07002807 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002808 testName = "conv2d_{}x{}".format(k[0], k[1])
2809 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2810 self.TOSA_OP_LIST[testName]["filter"] = k
2811 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002812
Kevin Cheng550ccc52021-03-03 11:21:43 -08002813 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2814 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2815 "depthwise_conv2d_TEMPLATE"
2816 ].copy()
2817 self.TOSA_OP_LIST[testName]["filter"] = k
2818 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002819
Kevin Cheng550ccc52021-03-03 11:21:43 -08002820 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2821 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2822 "transpose_conv2d_TEMPLATE"
2823 ].copy()
2824 self.TOSA_OP_LIST[testName]["filter"] = k
2825 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002826
Kevin Cheng1533b852021-09-01 12:51:58 -07002827 for k in KERNELS_3D:
2828 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2829 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2830 self.TOSA_OP_LIST[testName]["filter"] = k
2831 self.TOSA_OP_LIST[testName]["template"] = False
2832
Eric Kunzee5e26762020-10-13 16:11:07 -07002833 # Delete any templates after having created any dynamic ops
2834 # This is a two-pass operation because it's bad practice to delete
2835 # keys from dictionaries while iterating
2836 keyList = []
2837 for k in self.TOSA_OP_LIST:
2838 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002839 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002840 keyList.append(k)
2841 continue
2842 except KeyError:
2843 pass
2844
2845 for k in keyList:
2846 del self.TOSA_OP_LIST[k]
2847
2848 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002849 """Fill in default fields for ops if they aren't already specified.
2850 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002851 for op in self.TOSA_OP_LIST:
2852
2853 # Required fields
2854 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002855 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002856 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002857 raise Exception(
2858 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2859 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002860
2861 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002862 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002863 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002864 raise Exception(
2865 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2866 op
2867 )
2868 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002869
2870 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002871 _ = self.TOSA_OP_LIST[op]["types"]
2872 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002873 raise Exception(
2874 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2875 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002876
2877 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002878 _ = self.TOSA_OP_LIST[op]["op"]
2879 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002880 raise Exception(
2881 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2882 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002883
2884 # Put in default rank range, if missing
2885 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002886 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002887 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002888 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002889
    # Tensor operator list (TOSA_OP_LIST below), keyed by test op name
    # 'op': the Op enum value of the operator
    # 'operands': tuple of (placeholder, const) operand counts
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #         if not specified, defaults to DEFAULT_RANK_RANGE
    #         (filled in by initOpListDefaults)
    # 'build_fcn': tuple of (build function, TensorGen function,
    #              TensorValuesGen function, ArgGen function)
    # 'types': array of datatypes to be tested
    # Optional fields used by some entries:
    # 'qgen': quantization info generator, called by serializeTest
    # 'invalid_test_validators': functions that remove positive tests which
    #     would fail without correlating to an ERROR_IF statement
    # 'error_if_validators': ERROR_IF validators, passed to the build function
    #     as validator_fcns
    # 'data_gen': data generator types per type group, e.g. {"fp": (...)}
    #     NOTE(review): consumed outside this section - confirm semantics
    # 'compliance': extra compliance settings, e.g. {"ulp": 5}
    #     NOTE(review): consumed outside this section - confirm semantics
    # 'template': True for placeholder entries that createDynamicOpLists
    #     expands and then deletes
    # 'filter': kernel size; set on each op generated by createDynamicOpLists
    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.INT32,
    ]  # floating-types and INT32
    TYPE_FIB = [
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.BOOL,
    ]  # floating-types, integer types (excluding INT4) and BOOL
    TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

    # List of [Input Type 1, Input Type 2, Accumulator Type]
    # (used as the 'types' of the convolution and fully_connected ops)
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        [DType.FP16, DType.FP16, DType.FP16],
        [DType.FP16, DType.FP16, DType.FP32],
        [DType.BF16, DType.BF16, DType.FP32],
        [DType.FP32, DType.FP32, DType.FP32],
    ]

    # 'rank' range given by initOpListDefaults to ops that do not specify one
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002941
2942 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002943 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002944 "argmax": {
2945 "op": Op.ARGMAX,
2946 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002947 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002948 "build_fcn": (
2949 build_argmax,
2950 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002951 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002952 TosaArgGen.agAxis,
2953 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002954 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002955 "error_if_validators": (
2956 TosaErrorValidator.evAxisSmallerZero,
2957 TosaErrorValidator.evAxisLargerRank,
2958 TosaErrorValidator.evArgmaxOutputRankMismatch,
2959 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2960 TosaErrorValidator.evWrongRank,
2961 TosaErrorValidator.evWrongInputType,
2962 TosaErrorValidator.evWrongOutputType,
2963 TosaErrorValidator.evWrongInputList,
2964 TosaErrorValidator.evWrongOutputList,
2965 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002966 "data_gen": {
2967 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2968 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 "avg_pool2d": {
2971 "op": Op.AVG_POOL2D,
2972 "operands": (1, 0),
2973 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002974 "build_fcn": (
2975 build_pool2d,
2976 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002977 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002978 TosaArgGen.agPooling,
2979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 "qgen": TosaQuantGen.qgUnary,
2981 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002982 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002983 "error_if_validators": (
2984 TosaErrorValidator.evKernelSmallerOne,
2985 TosaErrorValidator.evStrideSmallerOne,
2986 TosaErrorValidator.evPadSmallerZero,
2987 TosaErrorValidator.evWrongRank,
2988 TosaErrorValidator.evWrongInputType,
2989 TosaErrorValidator.evWrongOutputType,
2990 TosaErrorValidator.evWrongInputList,
2991 TosaErrorValidator.evWrongOutputList,
2992 TosaErrorValidator.evInputZeroPointNotZero,
2993 TosaErrorValidator.evOutputZeroPointNotZero,
2994 TosaErrorValidator.evPadLargerEqualKernel,
2995 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002996 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002997 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00002998 "data_gen": {
2999 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003001 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003002 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003003 "conv2d_TEMPLATE": {
3004 "op": Op.CONV2D,
3005 "operands": (1, 2),
3006 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003007 "build_fcn": (
3008 build_conv2d,
3009 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003010 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003011 TosaArgGen.agConv,
3012 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003013 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003014 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003015 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3016 "error_if_validators": (
3017 TosaErrorValidator.evWrongInputType,
3018 TosaErrorValidator.evWrongOutputType,
3019 TosaErrorValidator.evWrongInputList,
3020 TosaErrorValidator.evWrongOutputList,
3021 TosaErrorValidator.evInputZeroPointNotZero,
3022 TosaErrorValidator.evWeightZeroPointNotZero,
3023 TosaErrorValidator.evPadSmallerZero,
3024 TosaErrorValidator.evStrideSmallerOne,
3025 TosaErrorValidator.evDilationSmallerOne,
3026 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003027 TosaErrorValidator.evConvOutputShapeMismatch,
3028 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003029 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003030 "data_gen": {
3031 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3032 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003033 "template": True,
3034 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003035 # Templated operator. Filled in by createDynamicOpLists
3036 "conv3d_TEMPLATE": {
3037 "op": Op.CONV3D,
3038 "operands": (1, 2),
3039 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003040 "build_fcn": (
3041 build_conv3d,
3042 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003043 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 TosaArgGen.agConv,
3045 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003046 "qgen": TosaQuantGen.qgConv,
3047 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003048 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3049 "error_if_validators": (
3050 TosaErrorValidator.evWrongInputType,
3051 TosaErrorValidator.evWrongOutputType,
3052 TosaErrorValidator.evWrongInputList,
3053 TosaErrorValidator.evWrongOutputList,
3054 TosaErrorValidator.evInputZeroPointNotZero,
3055 TosaErrorValidator.evWeightZeroPointNotZero,
3056 TosaErrorValidator.evPadSmallerZero,
3057 TosaErrorValidator.evStrideSmallerOne,
3058 TosaErrorValidator.evDilationSmallerOne,
3059 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003060 TosaErrorValidator.evConvOutputShapeMismatch,
3061 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003062 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003063 "template": True,
3064 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003065 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003066 "depthwise_conv2d_TEMPLATE": {
3067 "op": Op.DEPTHWISE_CONV2D,
3068 "operands": (1, 2),
3069 "filter": [1, 1],
3070 "rank": (4, 4),
3071 "build_fcn": (
3072 build_depthwise_conv2d,
3073 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003074 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003075 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003076 ),
3077 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003078 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003079 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3080 "error_if_validators": (
3081 TosaErrorValidator.evWrongInputType,
3082 TosaErrorValidator.evWrongOutputType,
3083 TosaErrorValidator.evWrongInputList,
3084 TosaErrorValidator.evWrongOutputList,
3085 TosaErrorValidator.evInputZeroPointNotZero,
3086 TosaErrorValidator.evWeightZeroPointNotZero,
3087 TosaErrorValidator.evPadSmallerZero,
3088 TosaErrorValidator.evStrideSmallerOne,
3089 TosaErrorValidator.evDilationSmallerOne,
3090 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003091 TosaErrorValidator.evConvOutputShapeMismatch,
3092 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003093 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003094 "template": True,
3095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "fully_connected": {
3097 "op": Op.FULLY_CONNECTED,
3098 "operands": (1, 2),
3099 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003100 "build_fcn": (
3101 build_fully_connected,
3102 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003103 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003104 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003105 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003106 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003107 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 "error_if_validators": (
3109 TosaErrorValidator.evInputZeroPointNotZero,
3110 TosaErrorValidator.evWeightZeroPointNotZero,
3111 TosaErrorValidator.evWrongRank,
3112 TosaErrorValidator.evWrongInputType,
3113 TosaErrorValidator.evWrongOutputType,
3114 TosaErrorValidator.evWrongInputList,
3115 TosaErrorValidator.evWrongOutputList,
3116 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003117 "data_gen": {
3118 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003120 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003121 "matmul": {
3122 "op": Op.MATMUL,
3123 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003124 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 "build_fcn": (
3126 build_matmul,
3127 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003128 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003129 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 "qgen": TosaQuantGen.qgMatmul,
3132 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003133 "error_if_validators": (
3134 TosaErrorValidator.evInputZeroPointNotZero,
3135 TosaErrorValidator.evWrongRank,
3136 TosaErrorValidator.evWrongInputType,
3137 TosaErrorValidator.evWrongOutputType,
3138 TosaErrorValidator.evWrongInputList,
3139 TosaErrorValidator.evWrongOutputList,
3140 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003141 "data_gen": {
3142 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003144 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 "max_pool2d": {
3146 "op": Op.MAX_POOL2D,
3147 "operands": (1, 0),
3148 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003150 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003151 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003152 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003153 TosaArgGen.agPooling,
3154 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003155 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003156 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003157 "error_if_validators": (
3158 TosaErrorValidator.evKernelSmallerOne,
3159 TosaErrorValidator.evStrideSmallerOne,
3160 TosaErrorValidator.evPadSmallerZero,
3161 TosaErrorValidator.evWrongRank,
3162 TosaErrorValidator.evWrongInputType,
3163 TosaErrorValidator.evWrongOutputType,
3164 TosaErrorValidator.evWrongInputList,
3165 TosaErrorValidator.evWrongOutputList,
3166 TosaErrorValidator.evPadLargerEqualKernel,
3167 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003168 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003169 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003170 "data_gen": {
3171 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3172 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003173 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003174 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003175 "transpose_conv2d_TEMPLATE": {
3176 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003177 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003178 "rank": (4, 4),
3179 "build_fcn": (
3180 build_transpose_conv2d,
3181 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003182 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003183 TosaArgGen.agTransposeConv2D,
3184 ),
3185 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003186 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003187 "invalid_test_validators": (
3188 TosaInvalidValidator.ivHeightWidthInvalid,
3189 TosaInvalidValidator.ivNonPositiveOutputShape,
3190 ),
3191 "error_if_validators": (
3192 TosaErrorValidator.evWrongInputType,
3193 TosaErrorValidator.evWrongOutputType,
3194 TosaErrorValidator.evWrongInputList,
3195 TosaErrorValidator.evWrongOutputList,
3196 TosaErrorValidator.evInputZeroPointNotZero,
3197 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003198 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003199 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003200 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003201 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003202 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003203 "template": True,
3204 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003205 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003206 "clamp": {
3207 "op": Op.CLAMP,
3208 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 "build_fcn": (
3210 build_clamp,
3211 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003212 TosaTensorValuesGen.tvgLazyGenDefault,
3213 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003214 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003215 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003216 "error_if_validators": (
3217 TosaErrorValidator.evMaxSmallerMin,
3218 TosaErrorValidator.evWrongInputType,
3219 TosaErrorValidator.evWrongOutputType,
3220 TosaErrorValidator.evWrongInputList,
3221 TosaErrorValidator.evWrongOutputList,
3222 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003223 "data_gen": {
3224 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3225 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003227 "sigmoid": {
3228 "op": Op.SIGMOID,
3229 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003230 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003231 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003232 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003233 TosaTensorValuesGen.tvgLazyGenDefault,
3234 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003235 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003236 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003237 "error_if_validators": (
3238 TosaErrorValidator.evWrongInputType,
3239 TosaErrorValidator.evWrongOutputType,
3240 TosaErrorValidator.evWrongInputList,
3241 TosaErrorValidator.evWrongOutputList,
3242 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003243 "data_gen": {
3244 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3245 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003246 },
3247 "tanh": {
3248 "op": Op.TANH,
3249 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003251 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003253 TosaTensorValuesGen.tvgLazyGenDefault,
3254 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003256 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003257 "error_if_validators": (
3258 TosaErrorValidator.evWrongInputType,
3259 TosaErrorValidator.evWrongOutputType,
3260 TosaErrorValidator.evWrongInputList,
3261 TosaErrorValidator.evWrongOutputList,
3262 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003263 "data_gen": {
3264 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3265 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003266 },
Won Jeon78155c62023-06-10 00:20:04 +00003267 "erf": {
3268 "op": Op.ERF,
3269 "operands": (1, 0),
3270 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003271 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003272 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003273 TosaTensorValuesGen.tvgLazyGenDefault,
3274 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003275 ),
3276 "types": TYPE_FP,
3277 "error_if_validators": (
3278 TosaErrorValidator.evWrongInputType,
3279 TosaErrorValidator.evWrongOutputType,
3280 TosaErrorValidator.evWrongInputList,
3281 TosaErrorValidator.evWrongOutputList,
3282 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003283 "data_gen": {
3284 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3285 },
3286 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003287 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 # Elementwise Binary Operators
3289 "add": {
3290 "op": Op.ADD,
3291 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003292 "build_fcn": (
3293 build_binary_broadcast,
3294 TosaTensorGen.tgBroadcastFuzz,
3295 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003296 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003297 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 "error_if_validators": (
3300 TosaErrorValidator.evRankMismatch,
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003306 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003307 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003308 "data_gen": {
3309 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3310 },
3311 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "arithmetic_right_shift": {
3314 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3315 "operands": (2, 0),
3316 "build_fcn": (
3317 build_arithmetic_right_shift,
3318 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003319 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 TosaArgGen.agArithmeticRightShift,
3321 ),
3322 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003323 "error_if_validators": (
3324 TosaErrorValidator.evRankMismatch,
3325 TosaErrorValidator.evWrongInputType,
3326 TosaErrorValidator.evWrongOutputType,
3327 TosaErrorValidator.evWrongInputList,
3328 TosaErrorValidator.evWrongOutputList,
3329 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003330 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "bitwise_and": {
3334 "op": Op.BITWISE_AND,
3335 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003336 "build_fcn": (
3337 build_binary_broadcast,
3338 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003339 TosaTensorValuesGen.tvgLazyGenDefault,
3340 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003343 "error_if_validators": (
3344 TosaErrorValidator.evRankMismatch,
3345 TosaErrorValidator.evWrongInputType,
3346 TosaErrorValidator.evWrongOutputType,
3347 TosaErrorValidator.evWrongInputList,
3348 TosaErrorValidator.evWrongOutputList,
3349 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003350 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "bitwise_or": {
3354 "op": Op.BITWISE_OR,
3355 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003356 "build_fcn": (
3357 build_binary_broadcast,
3358 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003359 TosaTensorValuesGen.tvgLazyGenDefault,
3360 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003361 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003362 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003363 "error_if_validators": (
3364 TosaErrorValidator.evRankMismatch,
3365 TosaErrorValidator.evWrongInputType,
3366 TosaErrorValidator.evWrongOutputType,
3367 TosaErrorValidator.evWrongInputList,
3368 TosaErrorValidator.evWrongOutputList,
3369 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003370 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003371 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 "bitwise_xor": {
3374 "op": Op.BITWISE_XOR,
3375 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003376 "build_fcn": (
3377 build_binary_broadcast,
3378 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003379 TosaTensorValuesGen.tvgLazyGenDefault,
3380 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003383 "error_if_validators": (
3384 TosaErrorValidator.evRankMismatch,
3385 TosaErrorValidator.evWrongInputType,
3386 TosaErrorValidator.evWrongOutputType,
3387 TosaErrorValidator.evWrongInputList,
3388 TosaErrorValidator.evWrongOutputList,
3389 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003390 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003393 "intdiv": {
3394 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003395 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003396 "build_fcn": (
3397 build_binary_broadcast,
3398 TosaTensorGen.tgBroadcastFuzz,
3399 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003400 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003402 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003403 "error_if_validators": (
3404 TosaErrorValidator.evRankMismatch,
3405 TosaErrorValidator.evWrongInputType,
3406 TosaErrorValidator.evWrongOutputType,
3407 TosaErrorValidator.evWrongInputList,
3408 TosaErrorValidator.evWrongOutputList,
3409 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003410 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003411 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "logical_and": {
3414 "op": Op.LOGICAL_AND,
3415 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003416 "build_fcn": (
3417 build_binary_broadcast,
3418 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003419 TosaTensorValuesGen.tvgLazyGenDefault,
3420 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 "error_if_validators": (
3424 TosaErrorValidator.evRankMismatch,
3425 TosaErrorValidator.evWrongInputType,
3426 TosaErrorValidator.evWrongOutputType,
3427 TosaErrorValidator.evWrongInputList,
3428 TosaErrorValidator.evWrongOutputList,
3429 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003430 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "logical_left_shift": {
3434 "op": Op.LOGICAL_LEFT_SHIFT,
3435 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003436 "build_fcn": (
3437 build_binary_broadcast,
3438 TosaTensorGen.tgBroadcastFuzz,
3439 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003440 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003442 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003443 "error_if_validators": (
3444 TosaErrorValidator.evRankMismatch,
3445 TosaErrorValidator.evWrongInputType,
3446 TosaErrorValidator.evWrongOutputType,
3447 TosaErrorValidator.evWrongInputList,
3448 TosaErrorValidator.evWrongOutputList,
3449 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003450 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003452 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "logical_right_shift": {
3454 "op": Op.LOGICAL_RIGHT_SHIFT,
3455 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003456 "build_fcn": (
3457 build_binary_broadcast,
3458 TosaTensorGen.tgBroadcastFuzz,
3459 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003460 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003463 "error_if_validators": (
3464 TosaErrorValidator.evRankMismatch,
3465 TosaErrorValidator.evWrongInputType,
3466 TosaErrorValidator.evWrongOutputType,
3467 TosaErrorValidator.evWrongInputList,
3468 TosaErrorValidator.evWrongOutputList,
3469 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003470 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "logical_or": {
3474 "op": Op.LOGICAL_OR,
3475 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003476 "build_fcn": (
3477 build_binary_broadcast,
3478 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003479 TosaTensorValuesGen.tvgLazyGenDefault,
3480 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003481 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 "error_if_validators": (
3484 TosaErrorValidator.evRankMismatch,
3485 TosaErrorValidator.evWrongInputType,
3486 TosaErrorValidator.evWrongOutputType,
3487 TosaErrorValidator.evWrongInputList,
3488 TosaErrorValidator.evWrongOutputList,
3489 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003490 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003491 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "logical_xor": {
3494 "op": Op.LOGICAL_XOR,
3495 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003496 "build_fcn": (
3497 build_binary_broadcast,
3498 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003499 TosaTensorValuesGen.tvgLazyGenDefault,
3500 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003501 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003502 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003503 "error_if_validators": (
3504 TosaErrorValidator.evRankMismatch,
3505 TosaErrorValidator.evWrongInputType,
3506 TosaErrorValidator.evWrongOutputType,
3507 TosaErrorValidator.evWrongInputList,
3508 TosaErrorValidator.evWrongOutputList,
3509 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003510 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003511 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "maximum": {
3514 "op": Op.MAXIMUM,
3515 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003516 "build_fcn": (
3517 build_binary_broadcast,
3518 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003519 TosaTensorValuesGen.tvgLazyGenDefault,
3520 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003522 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003523 "error_if_validators": (
3524 TosaErrorValidator.evRankMismatch,
3525 TosaErrorValidator.evWrongInputType,
3526 TosaErrorValidator.evWrongOutputType,
3527 TosaErrorValidator.evWrongInputList,
3528 TosaErrorValidator.evWrongOutputList,
3529 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003530 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003531 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003532 "data_gen": {
3533 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3534 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "minimum": {
3537 "op": Op.MINIMUM,
3538 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_binary_broadcast,
3541 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003542 TosaTensorValuesGen.tvgLazyGenDefault,
3543 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evRankMismatch,
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003553 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003555 "data_gen": {
3556 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3557 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "mul": {
3560 "op": Op.MUL,
3561 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 "build_fcn": (
3563 build_mul,
3564 TosaTensorGen.tgBroadcastFuzz,
3565 TosaTensorValuesGen.tvgMul,
3566 TosaArgGen.agMul,
3567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 "error_if_validators": (
3570 TosaErrorValidator.evWrongInputType,
3571 TosaErrorValidator.evWrongOutputType,
3572 TosaErrorValidator.evWrongInputList,
3573 TosaErrorValidator.evWrongOutputList,
3574 TosaErrorValidator.evRankMismatch,
3575 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003576 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003577 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003578 "data_gen": {
3579 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3580 },
3581 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "pow": {
3584 "op": Op.POW,
3585 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_binary_broadcast,
3588 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003589 TosaTensorValuesGen.tvgPow,
3590 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evRankMismatch,
3595 TosaErrorValidator.evWrongInputType,
3596 TosaErrorValidator.evWrongOutputType,
3597 TosaErrorValidator.evWrongInputList,
3598 TosaErrorValidator.evWrongOutputList,
3599 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003600 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003602 "data_gen": {
3603 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3604 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 "sub": {
3607 "op": Op.SUB,
3608 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003609 "build_fcn": (
3610 build_binary_broadcast,
3611 TosaTensorGen.tgBroadcastFuzz,
3612 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003613 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 "error_if_validators": (
3617 TosaErrorValidator.evRankMismatch,
3618 TosaErrorValidator.evWrongInputType,
3619 TosaErrorValidator.evWrongOutputType,
3620 TosaErrorValidator.evWrongInputList,
3621 TosaErrorValidator.evWrongOutputList,
3622 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003623 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003625 "data_gen": {
3626 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3627 },
3628 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 "table": {
3631 "op": Op.TABLE,
3632 # Use the automatic generation functions to create the input array
3633 # but create the table tensor in the build function, as it may be
3634 # a different type from the input
3635 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003636 "build_fcn": (
3637 build_table,
3638 TosaTensorGen.tgBasic,
3639 TosaTensorValuesGen.tvgDefault,
3640 TosaArgGen.agTable,
3641 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003642 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003643 "error_if_validators": (
3644 TosaErrorValidator.evWrongInputType,
3645 TosaErrorValidator.evWrongOutputType,
3646 TosaErrorValidator.evWrongInputList,
3647 TosaErrorValidator.evWrongOutputList,
3648 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003650 # Elementwise Unary operators
3651 "abs": {
3652 "op": Op.ABS,
3653 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 "build_fcn": (
3655 build_unary,
3656 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003657 TosaTensorValuesGen.tvgLazyGenDefault,
3658 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003661 "error_if_validators": (
3662 TosaErrorValidator.evWrongInputType,
3663 TosaErrorValidator.evWrongOutputType,
3664 TosaErrorValidator.evWrongInputList,
3665 TosaErrorValidator.evWrongOutputList,
3666 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003667 "data_gen": {
3668 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003670 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003671 "bitwise_not": {
3672 "op": Op.BITWISE_NOT,
3673 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003674 "build_fcn": (
3675 build_unary,
3676 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003677 TosaTensorValuesGen.tvgLazyGenDefault,
3678 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003680 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003681 "error_if_validators": (
3682 TosaErrorValidator.evWrongInputType,
3683 TosaErrorValidator.evWrongOutputType,
3684 TosaErrorValidator.evWrongInputList,
3685 TosaErrorValidator.evWrongOutputList,
3686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 "ceil": {
3689 "op": Op.CEIL,
3690 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003691 "build_fcn": (
3692 build_unary,
3693 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003694 TosaTensorValuesGen.tvgLazyGenDefault,
3695 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003696 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003697 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003698 "error_if_validators": (
3699 TosaErrorValidator.evWrongInputType,
3700 TosaErrorValidator.evWrongOutputType,
3701 TosaErrorValidator.evWrongInputList,
3702 TosaErrorValidator.evWrongOutputList,
3703 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003704 "data_gen": {
3705 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3706 },
3707 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003709 "clz": {
3710 "op": Op.CLZ,
3711 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 "build_fcn": (
3713 build_unary,
3714 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003715 TosaTensorValuesGen.tvgLazyGenDefault,
3716 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003718 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 "error_if_validators": (
3720 TosaErrorValidator.evWrongInputType,
3721 TosaErrorValidator.evWrongOutputType,
3722 TosaErrorValidator.evWrongInputList,
3723 TosaErrorValidator.evWrongOutputList,
3724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003726 "exp": {
3727 "op": Op.EXP,
3728 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 "build_fcn": (
3730 build_unary,
3731 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003732 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003733 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003735 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evWrongInputType,
3738 TosaErrorValidator.evWrongOutputType,
3739 TosaErrorValidator.evWrongInputList,
3740 TosaErrorValidator.evWrongOutputList,
3741 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003742 "data_gen": {
3743 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 "floor": {
3747 "op": Op.FLOOR,
3748 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 "build_fcn": (
3750 build_unary,
3751 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003752 TosaTensorValuesGen.tvgLazyGenDefault,
3753 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 "error_if_validators": (
3757 TosaErrorValidator.evWrongInputType,
3758 TosaErrorValidator.evWrongOutputType,
3759 TosaErrorValidator.evWrongInputList,
3760 TosaErrorValidator.evWrongOutputList,
3761 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003762 "data_gen": {
3763 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3764 },
3765 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003767 "log": {
3768 "op": Op.LOG,
3769 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003770 "build_fcn": (
3771 build_unary,
3772 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003773 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003774 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003776 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003777 "error_if_validators": (
3778 TosaErrorValidator.evWrongInputType,
3779 TosaErrorValidator.evWrongOutputType,
3780 TosaErrorValidator.evWrongInputList,
3781 TosaErrorValidator.evWrongOutputList,
3782 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003783 "data_gen": {
3784 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3785 },
3786 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003787 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 "logical_not": {
3789 "op": Op.LOGICAL_NOT,
3790 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003791 "build_fcn": (
3792 build_unary,
3793 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003794 TosaTensorValuesGen.tvgLazyGenDefault,
3795 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003796 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 "error_if_validators": (
3799 TosaErrorValidator.evWrongInputType,
3800 TosaErrorValidator.evWrongOutputType,
3801 TosaErrorValidator.evWrongInputList,
3802 TosaErrorValidator.evWrongOutputList,
3803 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 "negate": {
3806 "op": Op.NEGATE,
3807 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003808 "build_fcn": (
3809 build_unary,
3810 TosaTensorGen.tgBasic,
3811 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003812 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003813 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003814 "qgen": TosaQuantGen.qgUnary,
3815 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evInputZeroPointNotZero,
3818 TosaErrorValidator.evOutputZeroPointNotZero,
3819 TosaErrorValidator.evWrongInputType,
3820 TosaErrorValidator.evWrongOutputType,
3821 TosaErrorValidator.evWrongInputList,
3822 TosaErrorValidator.evWrongOutputList,
3823 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003824 "data_gen": {
3825 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3826 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "reciprocal": {
3829 "op": Op.RECIPROCAL,
3830 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003831 "build_fcn": (
3832 build_unary,
3833 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003834 TosaTensorValuesGen.tvgLazyGenDefault,
3835 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003838 "error_if_validators": (
3839 TosaErrorValidator.evWrongInputType,
3840 TosaErrorValidator.evWrongOutputType,
3841 TosaErrorValidator.evWrongInputList,
3842 TosaErrorValidator.evWrongOutputList,
3843 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003844 "data_gen": {
3845 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3846 },
3847 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003848 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003849 "rsqrt": {
3850 "op": Op.RSQRT,
3851 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 "build_fcn": (
3853 build_unary,
3854 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003855 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003856 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 "error_if_validators": (
3860 TosaErrorValidator.evWrongInputType,
3861 TosaErrorValidator.evWrongOutputType,
3862 TosaErrorValidator.evWrongInputList,
3863 TosaErrorValidator.evWrongOutputList,
3864 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003865 "data_gen": {
3866 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3867 },
3868 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003869 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 # Elementwise Ternary operators
3871 "select": {
3872 "op": Op.SELECT,
3873 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003874 "build_fcn": (
3875 build_select,
3876 TosaTensorGen.tgBroadcastFuzz,
3877 TosaTensorValuesGen.tvgSelect,
3878 None,
3879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 "error_if_validators": (
3882 TosaErrorValidator.evRankMismatch,
3883 TosaErrorValidator.evWrongInputType,
3884 TosaErrorValidator.evWrongOutputType,
3885 TosaErrorValidator.evWrongInputList,
3886 TosaErrorValidator.evWrongOutputList,
3887 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003888 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003891 # Comparison operators
3892 "equal": {
3893 "op": Op.EQUAL,
3894 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003895 "build_fcn": (
3896 build_comparison,
3897 TosaTensorGen.tgBroadcastFuzz,
3898 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003899 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003900 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003901 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003902 "error_if_validators": (
3903 TosaErrorValidator.evRankMismatch,
3904 TosaErrorValidator.evWrongInputType,
3905 TosaErrorValidator.evWrongOutputType,
3906 TosaErrorValidator.evWrongInputList,
3907 TosaErrorValidator.evWrongOutputList,
3908 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003909 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003911 "data_gen": {
3912 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003914 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 "greater_equal": {
3916 "op": Op.GREATER_EQUAL,
3917 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003918 "build_fcn": (
3919 build_comparison,
3920 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003921 TosaTensorValuesGen.tvgLazyGenDefault,
3922 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003923 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003924 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 "error_if_validators": (
3926 TosaErrorValidator.evRankMismatch,
3927 TosaErrorValidator.evWrongInputType,
3928 TosaErrorValidator.evWrongOutputType,
3929 TosaErrorValidator.evWrongInputList,
3930 TosaErrorValidator.evWrongOutputList,
3931 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003932 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003933 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003934 "data_gen": {
3935 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003937 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003938 "greater": {
3939 "op": Op.GREATER,
3940 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003941 "build_fcn": (
3942 build_comparison,
3943 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003944 TosaTensorValuesGen.tvgLazyGenDefault,
3945 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003946 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003948 "error_if_validators": (
3949 TosaErrorValidator.evRankMismatch,
3950 TosaErrorValidator.evWrongInputType,
3951 TosaErrorValidator.evWrongOutputType,
3952 TosaErrorValidator.evWrongInputList,
3953 TosaErrorValidator.evWrongOutputList,
3954 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003955 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003956 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003957 "data_gen": {
3958 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003960 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 # Reduction operators
3962 "reduce_all": {
3963 "op": Op.REDUCE_ALL,
3964 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003965 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 "build_fcn": (
3967 build_reduce,
3968 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003969 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 TosaArgGen.agAxis,
3971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003972 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003973 "error_if_validators": (
3974 TosaErrorValidator.evAxisLargerRank,
3975 TosaErrorValidator.evAxisSmallerZero,
3976 TosaErrorValidator.evShapeOfAxisNotOne,
3977 TosaErrorValidator.evWrongInputType,
3978 TosaErrorValidator.evWrongOutputType,
3979 TosaErrorValidator.evWrongRank,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "reduce_any": {
3985 "op": Op.REDUCE_ANY,
3986 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003987 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003988 "build_fcn": (
3989 build_reduce,
3990 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003991 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 TosaArgGen.agAxis,
3993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 "error_if_validators": (
3996 TosaErrorValidator.evAxisLargerRank,
3997 TosaErrorValidator.evAxisSmallerZero,
3998 TosaErrorValidator.evShapeOfAxisNotOne,
3999 TosaErrorValidator.evWrongInputType,
4000 TosaErrorValidator.evWrongOutputType,
4001 TosaErrorValidator.evWrongRank,
4002 TosaErrorValidator.evWrongInputList,
4003 TosaErrorValidator.evWrongOutputList,
4004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004005 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 "reduce_max": {
4007 "op": Op.REDUCE_MAX,
4008 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004009 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004010 "build_fcn": (
4011 build_reduce,
4012 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004013 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 TosaArgGen.agAxis,
4015 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 "error_if_validators": (
4018 TosaErrorValidator.evAxisLargerRank,
4019 TosaErrorValidator.evAxisSmallerZero,
4020 TosaErrorValidator.evShapeOfAxisNotOne,
4021 TosaErrorValidator.evWrongInputType,
4022 TosaErrorValidator.evWrongOutputType,
4023 TosaErrorValidator.evWrongRank,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004027 "data_gen": {
4028 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004032 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004034 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004035 "build_fcn": (
4036 build_reduce,
4037 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004038 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 TosaArgGen.agAxis,
4040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004042 "error_if_validators": (
4043 TosaErrorValidator.evAxisLargerRank,
4044 TosaErrorValidator.evAxisSmallerZero,
4045 TosaErrorValidator.evShapeOfAxisNotOne,
4046 TosaErrorValidator.evWrongInputType,
4047 TosaErrorValidator.evWrongOutputType,
4048 TosaErrorValidator.evWrongRank,
4049 TosaErrorValidator.evWrongInputList,
4050 TosaErrorValidator.evWrongOutputList,
4051 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004052 "data_gen": {
4053 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4054 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 "reduce_product": {
4057 "op": Op.REDUCE_PRODUCT,
4058 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004059 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004060 "build_fcn": (
4061 build_reduce,
4062 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004063 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 TosaArgGen.agAxis,
4065 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004067 "error_if_validators": (
4068 TosaErrorValidator.evAxisLargerRank,
4069 TosaErrorValidator.evAxisSmallerZero,
4070 TosaErrorValidator.evShapeOfAxisNotOne,
4071 TosaErrorValidator.evWrongInputType,
4072 TosaErrorValidator.evWrongOutputType,
4073 TosaErrorValidator.evWrongRank,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004078 "reduce_sum": {
4079 "op": Op.REDUCE_SUM,
4080 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004081 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 "build_fcn": (
4083 build_reduce,
4084 TosaTensorGen.tgBasic,
4085 TosaTensorValuesGen.tvgReduceSum,
4086 TosaArgGen.agAxis,
4087 ),
James Ward24dbc422022-10-19 12:20:31 +01004088 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004089 "error_if_validators": (
4090 TosaErrorValidator.evAxisLargerRank,
4091 TosaErrorValidator.evAxisSmallerZero,
4092 TosaErrorValidator.evShapeOfAxisNotOne,
4093 TosaErrorValidator.evWrongInputType,
4094 TosaErrorValidator.evWrongOutputType,
4095 TosaErrorValidator.evWrongRank,
4096 TosaErrorValidator.evWrongInputList,
4097 TosaErrorValidator.evWrongOutputList,
4098 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004099 "data_gen": {
4100 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4101 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004103 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004104 "concat": {
4105 "op": Op.CONCAT,
4106 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 "build_fcn": (
4108 build_concat,
4109 TosaTensorGen.tgConcat,
4110 TosaTensorValuesGen.tvgConcat,
4111 TosaArgGen.agAxis,
4112 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004113 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004114 "error_if_validators": (
4115 TosaErrorValidator.evAxisLargerRank,
4116 TosaErrorValidator.evAxisSmallerZero,
4117 TosaErrorValidator.evConcatInputRankMismatch,
4118 TosaErrorValidator.evConcatShapeSumMismatch,
4119 TosaErrorValidator.evConcatInputDimMismatch,
4120 TosaErrorValidator.evWrongInputType,
4121 TosaErrorValidator.evWrongOutputType,
4122 TosaErrorValidator.evWrongOutputList,
4123 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004124 },
4125 "pad": {
4126 "op": Op.PAD,
4127 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 "build_fcn": (
4129 build_pad,
4130 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004131 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004132 TosaArgGen.agPad,
4133 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004134 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004135 "error_if_validators": (
4136 TosaErrorValidator.evWrongInputType,
4137 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004138 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004139 TosaErrorValidator.evWrongOutputType,
4140 TosaErrorValidator.evWrongInputList,
4141 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004142 TosaErrorValidator.evRankMismatch,
4143 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004144 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004145 "data_gen": {
4146 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4147 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004148 },
Won Jeona21b2e82023-08-10 10:33:01 +00004149 "dim": {
4150 "op": Op.DIM,
4151 "operands": (1, 0),
4152 "build_fcn": (
4153 build_dim,
4154 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004155 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004156 TosaArgGen.agAxis,
4157 ),
4158 "types": TYPE_FIB,
4159 "error_if_validators": (
4160 TosaErrorValidator.evAxisLargerRank,
4161 TosaErrorValidator.evAxisSmallerZero,
4162 TosaErrorValidator.evWrongInputType,
4163 TosaErrorValidator.evWrongInputList,
4164 TosaErrorValidator.evWrongOutputList,
4165 TosaErrorValidator.evWrongRank,
4166 ),
4167 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004168 "reshape": {
4169 "op": Op.RESHAPE,
4170 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004171 "build_fcn": (
4172 build_reshape,
4173 TosaTensorGen.tgBasic,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004174 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004175 TosaArgGen.agReshape,
4176 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004177 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004178 "error_if_validators": (
4179 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongInputList,
4183 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004184 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4185 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004186 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004187 "data_gen": {
4188 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4189 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004190 },
4191 "reverse": {
4192 "op": Op.REVERSE,
4193 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_reverse,
4196 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004197 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004198 TosaArgGen.agAxis,
4199 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004200 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evAxisSmallerZero,
4203 TosaErrorValidator.evAxisLargerRank,
4204 TosaErrorValidator.evWrongInputType,
4205 TosaErrorValidator.evWrongOutputType,
4206 TosaErrorValidator.evWrongInputList,
4207 TosaErrorValidator.evWrongOutputList,
4208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004209 },
4210 "slice": {
4211 "op": Op.SLICE,
4212 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004213 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004214 "build_fcn": (
4215 build_slice,
4216 TosaTensorGen.tgBasic,
4217 TosaTensorValuesGen.tvgDefault,
4218 TosaArgGen.agSlice,
4219 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004220 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004221 "error_if_validators": (
4222 TosaErrorValidator.evStartSmallerZero,
4223 TosaErrorValidator.evSizeSmallerEqualZero,
4224 TosaErrorValidator.evStartSizeOutsideBounds,
4225 TosaErrorValidator.evSizeOutputShapeMismatch,
4226 TosaErrorValidator.evInputSizeStartLengthMismatch,
4227 TosaErrorValidator.evWrongRank,
4228 TosaErrorValidator.evWrongInputType,
4229 TosaErrorValidator.evWrongOutputType,
4230 TosaErrorValidator.evWrongInputList,
4231 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004232 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004233 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004234 },
4235 "tile": {
4236 "op": Op.TILE,
4237 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004238 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004239 "build_fcn": (
4240 build_tile,
4241 TosaTensorGen.tgBasic,
4242 TosaTensorValuesGen.tvgDefault,
4243 TosaArgGen.agTile,
4244 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004245 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004246 "error_if_validators": (
4247 TosaErrorValidator.evWrongInputType,
4248 TosaErrorValidator.evWrongOutputType,
4249 TosaErrorValidator.evWrongInputList,
4250 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004251 TosaErrorValidator.evRankMismatch,
4252 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004253 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004254 },
4255 "transpose": {
4256 "op": Op.TRANSPOSE,
4257 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004258 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 "build_fcn": (
4260 build_transpose,
4261 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004262 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004263 TosaArgGen.agTranspose,
4264 ),
4265 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004266 "error_if_validators": (
4267 TosaErrorValidator.evIndexOutsideBounds,
4268 TosaErrorValidator.evIndexUsedTwice,
4269 TosaErrorValidator.evWrongInputType,
4270 TosaErrorValidator.evWrongOutputType,
4271 TosaErrorValidator.evWrongInputList,
4272 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004273 TosaErrorValidator.evWrongRank,
4274 TosaErrorValidator.evRankMismatch,
4275 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004276 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 # Data nodes
4279 "const": {
4280 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004281 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_const,
4284 TosaTensorGen.tgBasic,
4285 TosaTensorValuesGen.tvgDefault,
4286 None,
4287 ),
Luke Hutton65872422023-02-20 10:33:04 +00004288 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004290 "identity": {
4291 "op": Op.IDENTITY,
4292 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004293 "build_fcn": (
4294 build_unary,
4295 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004296 TosaTensorValuesGen.tvgLazyGenDefault,
4297 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004299 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004300 "data_gen": {
4301 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004303 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004304 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004305 "gather": {
4306 "op": Op.GATHER,
4307 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4308 "operands": (1, 0),
4309 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004310 "build_fcn": (
4311 build_gather,
4312 TosaTensorGen.tgBasic,
4313 TosaTensorValuesGen.tvgDefault,
4314 None,
4315 ),
James Ward24dbc422022-10-19 12:20:31 +01004316 "types": (
4317 DType.INT8,
4318 DType.INT16,
4319 DType.INT32,
4320 DType.FP16,
4321 DType.BF16,
4322 DType.FP32,
4323 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004324 "error_if_validators": (
4325 TosaErrorValidator.evWrongInputType,
4326 TosaErrorValidator.evWrongOutputType,
4327 TosaErrorValidator.evWrongInputList,
4328 TosaErrorValidator.evWrongOutputList,
4329 TosaErrorValidator.evWrongRank,
4330 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004331 },
4332 "scatter": {
4333 "op": Op.SCATTER,
4334 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004335 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004336 "operands": (2, 0),
4337 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004338 "build_fcn": (
4339 build_scatter,
4340 TosaTensorGen.tgScatter,
4341 TosaTensorValuesGen.tvgDefault,
4342 None,
4343 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004344 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004345 "error_if_validators": (
4346 TosaErrorValidator.evWrongInputType,
4347 TosaErrorValidator.evWrongOutputType,
4348 TosaErrorValidator.evWrongInputList,
4349 TosaErrorValidator.evWrongOutputList,
4350 TosaErrorValidator.evWrongRank,
4351 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004352 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004353 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004354 "resize": {
4355 "op": Op.RESIZE,
4356 "operands": (1, 0),
4357 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004358 "build_fcn": (
4359 build_resize,
4360 TosaTensorGen.tgNHWC,
4361 TosaTensorValuesGen.tvgDefault,
4362 TosaArgGen.agResize,
4363 ),
James Ward24dbc422022-10-19 12:20:31 +01004364 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004365 "invalid_test_validators": (
4366 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004367 ),
4368 "error_if_validators": (
4369 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004370 TosaErrorValidator.evScaleSmallerEqualZero,
4371 TosaErrorValidator.evScaleNLargerMax,
4372 TosaErrorValidator.evScaleDLargerMax,
4373 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004374 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004375 TosaErrorValidator.evBorderSmallerMin,
4376 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004377 TosaErrorValidator.evWrongInputType,
4378 TosaErrorValidator.evWrongOutputType,
4379 TosaErrorValidator.evWrongRank,
4380 TosaErrorValidator.evWrongInputList,
4381 TosaErrorValidator.evWrongOutputList,
4382 TosaErrorValidator.evBatchMismatch,
4383 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004384 TosaErrorValidator.evResizeOutputShapeMismatch,
4385 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004386 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004387 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004388 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004389 "cast": {
4390 "op": Op.CAST,
4391 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004392 "build_fcn": (
4393 build_cast,
4394 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004395 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004396 TosaArgGen.agCast,
4397 ),
James Ward8b390432022-08-12 20:48:56 +01004398 "types": (
4399 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004400 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004401 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004402 DType.INT8,
4403 DType.INT16,
4404 DType.INT32,
4405 DType.BOOL,
4406 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 "error_if_validators": (
4408 TosaErrorValidator.evWrongInputType,
4409 TosaErrorValidator.evWrongOutputType,
4410 TosaErrorValidator.evWrongInputList,
4411 TosaErrorValidator.evWrongOutputList,
4412 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004413 "data_gen": {
4414 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4415 },
4416 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004417 },
4418 "rescale": {
4419 "op": Op.RESCALE,
4420 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004421 "build_fcn": (
4422 build_rescale,
4423 TosaTensorGen.tgBasic,
4424 TosaTensorValuesGen.tvgDefault,
4425 TosaArgGen.agRescale,
4426 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004427 "types": [
4428 DType.UINT8,
4429 DType.INT8,
4430 DType.INT16,
4431 DType.INT32,
4432 DType.INT48,
4433 DType.UINT16,
4434 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004435 "error_if_validators": (
4436 TosaErrorValidator.evInputZeroPointNotZero,
4437 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004438 TosaErrorValidator.evU16InputZeroPointNotValid,
4439 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 TosaErrorValidator.evScaleTrue,
4441 TosaErrorValidator.evScaleNotTrue,
4442 TosaErrorValidator.evWrongInputType,
4443 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 TosaErrorValidator.evWrongInputList,
4445 TosaErrorValidator.evWrongOutputList,
4446 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004448 # Custom
4449 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004450 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output), and another that either adds
        # or subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004454 "cond_if_const": {
4455 "op": Op.COND_IF,
4456 "operands": (0, 2),
4457 "build_fcn": (
4458 build_cond_if_const,
4459 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004460 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004461 TosaArgGen.agCondIf,
4462 ),
4463 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004464 "error_if_validators": (
4465 TosaErrorValidator.evOutputListThenGraphMismatch,
4466 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004467 TosaErrorValidator.evCondIfCondNotMatchingBool,
4468 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004469 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004470 },
4471 "cond_if_binary": {
4472 "op": Op.COND_IF,
4473 "operands": (2, 0),
4474 "build_fcn": (
4475 build_cond_if_binary,
4476 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004477 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004478 TosaArgGen.agCondIf,
4479 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004480 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004481 "error_if_validators": (
4482 TosaErrorValidator.evInputListThenGraphMismatch,
4483 TosaErrorValidator.evInputListElseGraphMismatch,
4484 TosaErrorValidator.evOutputListThenGraphMismatch,
4485 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004486 TosaErrorValidator.evCondIfCondNotMatchingBool,
4487 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004489 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004490 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004491 "while_loop": {
4492 "op": Op.WHILE_LOOP,
4493 "operands": (0, 1),
4494 "build_fcn": (
4495 build_while_loop,
4496 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004497 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004498 TosaArgGen.agWhileLoop,
4499 ),
4500 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004501 "error_if_validators": (
4502 TosaErrorValidator.evInputListOutputListMismatch,
4503 TosaErrorValidator.evInputListCondGraphMismatch,
4504 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4505 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4506 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004507 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004508 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004509 },
Luke Hutton57287132023-02-06 14:54:18 +00004510 "fft2d": {
4511 "op": Op.FFT2D,
4512 "operands": (2, 0),
4513 "rank": (3, 3),
4514 "build_fcn": (
4515 build_fft2d,
4516 TosaTensorGen.tgFFT2d,
4517 TosaTensorValuesGen.tvgDefault,
4518 TosaArgGen.agFFT2d,
4519 ),
4520 "types": [DType.FP32],
4521 "error_if_validators": (
4522 TosaErrorValidator.evWrongInputType,
4523 TosaErrorValidator.evWrongOutputType,
4524 TosaErrorValidator.evWrongInputList,
4525 TosaErrorValidator.evWrongOutputList,
4526 TosaErrorValidator.evWrongRank,
4527 TosaErrorValidator.evBatchMismatch,
4528 TosaErrorValidator.evKernelNotPowerOfTwo,
4529 TosaErrorValidator.evFFTInputShapeMismatch,
4530 TosaErrorValidator.evFFTOutputShapeMismatch,
4531 ),
4532 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004533 "rfft2d": {
4534 "op": Op.RFFT2D,
4535 "operands": (1, 0),
4536 "rank": (3, 3),
4537 "build_fcn": (
4538 build_rfft2d,
4539 TosaTensorGen.tgRFFT2d,
4540 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004541 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004542 ),
4543 "types": [DType.FP32],
4544 "error_if_validators": (
4545 TosaErrorValidator.evWrongInputType,
4546 TosaErrorValidator.evWrongOutputType,
4547 TosaErrorValidator.evWrongInputList,
4548 TosaErrorValidator.evWrongOutputList,
4549 TosaErrorValidator.evWrongRank,
4550 TosaErrorValidator.evBatchMismatch,
4551 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004552 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004553 ),
4554 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004555 }
4556
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557
Eric Kunzee5e26762020-10-13 16:11:07 -07004558class OutputShaper:
4559 # Methods in this class compute the expected output shape and datatype
4560 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; it keeps no instance state to initialise."""
4563
    # These methods compute an output shape and dtype, register them with
    # ser.addOutput() and return whatever that call returns
4566 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting element-wise binary op.

    For a valid test (``error_name is None``), every size-1 dimension of
    ``a`` takes the matching dimension of ``b``. For any error case, a's
    dimensions are kept unchanged. ERROR_IF fuzzing:

    * DimensionMismatch: one randomly chosen dimension is enlarged by 1.
    * WrongOutputType: the dtype is drawn from a fixed list with a's dtype
      removed.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``integers`` and ``choice``
            (presumably a numpy ``Generator`` -- confirm).
        a, b: input tensors with ``shape`` and ``dtype``. Dtypes must match,
            and ranks must match unless RankMismatch is being generated.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # The fuzz index is drawn unconditionally (even when unused), so RNG
    # consumption does not depend on error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004599
4600 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of an element-wise op that does not broadcast.

    ``a`` and ``b`` must agree in rank, in every dimension and in dtype
    (checked with ``assert``). The output uses a fresh list copy of a's
    shape and a's dtype.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        a, b: input tensors with ``shape`` and ``dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    assert all(a_dim == b.shape[idx] for idx, a_dim in enumerate(a.shape))

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004611
4612 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an element-wise unary op.

    The output shape is always ``a.shape``. The dtype is a's dtype, except
    when generating WrongOutputType. In that case it is drawn from a fixed
    dtype list with a's dtype removed.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``choice``.
        a: input tensor with ``shape`` and ``dtype``.
        error_name: ErrorIf case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004630
4631 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a SELECT (cond ? a : b) op.

    The shape follows ``cond``. For a valid test, every size-1 dimension of
    ``cond`` becomes the max of the matching cond/a/b dimensions. For error
    cases, cond's dimensions are kept. DimensionMismatch enlarges one random
    dimension by 1. WrongOutputType draws the dtype from a fixed list with
    a's dtype removed.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``integers`` and ``choice``.
        cond, a, b: input tensors with ``shape``. ``a`` and ``b`` also need
            ``dtype``. Ranks must agree unless RankMismatch is generated.
        error_name: ErrorIf case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    valid_test = error_name is None
    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and valid_test:
            out_shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_dim)

    # Drawn unconditionally and bounded by a's rank (not cond's).
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004664
4665 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor of a broadcasting comparison op.

    Each size-1 dimension of ``a`` takes b's matching dimension whenever
    ``b`` has that axis. Unlike binaryBroadcastOp, this also happens for
    error cases. DimensionMismatch enlarges one random dimension by 1.
    WrongOutputType replaces BOOL with a dtype drawn from a fixed non-BOOL
    list.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``integers`` and ``choice``.
        a, b: input tensors with ``shape`` and ``dtype``. Dtypes must match,
            and ranks must match unless RankMismatch is being generated.
        error_name: ErrorIf case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    # BOOL is absent from this list, so every pick is a wrong type.
    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004698
4699 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a REDUCE_* op along ``axis``.

    The reduced axis becomes size 1, except for the axis-related error cases
    (AxisSmallerZero, AxisLargerRank, ShapeOfAxisNotOne), which keep a's
    shape. For ShapeOfAxisNotOne, if that axis is already 1 it is replaced
    by a random size in [2, 10). WrongOutputType draws the dtype from a
    fixed list with a's dtype removed.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``integers`` and ``choice``.
        a: input tensor; ``a.shape`` must support ``.copy()`` (a list here,
            presumably -- confirm).
        axis: index of the reduced dimension.
        error_name: ErrorIf case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004727
4728 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of an ARGMAX op along ``axis``.

    The reduced axis is removed from a's shape, except for the
    AxisSmallerZero/AxisLargerRank error cases. Fuzzing:

    * ArgmaxOutputRankMismatch: a random coin decides whether to drop the
      leading dim (only if more than one remains); otherwise a trailing 1
      is appended.
    * ArgmaxOutputShapeMismatch: each dim grows by a random amount in [1, 10).
    * WrongOutputType: the dtype is drawn from a fixed list minus INT32.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` builds the result.
        rng: random generator exposing ``integers`` and ``choice``.
        a: input tensor; ``a.shape`` must support ``.copy()`` and ``del``
            (a list here, presumably -- confirm).
        axis: index of the reduced dimension.
        error_name: ErrorIf case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {DType.INT32})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004761
4762 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV2D output tensor.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial output
    size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """
    # Spatial output sizes: index 0 is H, index 1 is W
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change selects H only (1), W only (2) or both (3); grow in
        # multiples of the stride so the non-integer error case is not hit
        for i in range(2):
            if change in (i + 1, 3):
                spatial[i] = spatial[i] + (rng.choice(choices) * strides[i])

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16); otherwise only the expected dtype
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004813
4814 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV3D output tensor.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial output
    size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """
    # Spatial output sizes: index 0 is D, 1 is H, 2 is W
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change selects D (1), H (2), W (3) or all three (4); grow in
        # multiples of the stride so the non-integer error case is not hit
        for i in range(3):
            if change in (i + 1, 4):
                spatial[i] = spatial[i] + (rng.choice(choices) * strides[i])

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], spatial[2], filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16); otherwise only the expected dtype
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4875
4876 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the DEPTHWISE_CONV2D output tensor.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial
    output size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """
    # Spatial output sizes: index 0 is H, index 1 is W (filter is HWCM, so
    # the kernel sizes are filter.shape[0] and filter.shape[1])
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change selects H only (1), W only (2) or both (3); grow in
        # multiples of the stride so the non-integer error case is not hit
        for i in range(2):
            if change in (i + 1, 3):
                spatial[i] = spatial[i] + (rng.choice(choices) * strides[i])

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], out_channels]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16); otherwise only the expected dtype
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004926
4927 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling op (input is NHWC).

    N and C are kept; H and W follow
    ``(in + pad_before + pad_after - kernel) // stride + 1``.
    ERROR_IF variants perturb H/W or pick a dtype different from the input's.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Incorrect stride/padding: the test is invalid anyway, so use 1x1
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow in multiples of the stride to avoid the non-integer error case
        if change in (1, 3):
            out_h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004964
4965 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the FULLY_CONNECTED output tensor.

    Shapes: ``input`` is [N, IC], ``filter`` is [OC, IC], output is [N, OC].
    The output dtype is always ``accum_dtype``; it is validated in arg_gen
    (and invalidated there for ERROR_IF cases), so ``error_name`` is unused.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004977
4978 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the MATMUL output tensor.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W], output is [N, H, W].
    The output dtype is ``accum_dtype`` (validated in arg_gen) unless an
    ERROR_IF case requests a wrong output dtype or the input dtype is wrong.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # No dedicated list for this input dtype: use every usable dtype
            # except the expected accumulator dtype, so incorrect_types is
            # always bound (avoids an UnboundLocalError)
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005025
5026 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the CONCAT output tensor.

    The output shape is the first input's shape with ``axis`` extended by the
    ``axis`` size of every other input. For ERROR_IF cases where that cannot
    be computed (rank mismatch or invalid axis), the first input's shape is
    used unchanged.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    can_sum_axis = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum_axis:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005061
5062 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the PAD output tensor.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ERROR_IF variants change one dimension by 1 or 2, change the rank, or
    pick a dtype different from the input's.
    """
    output_shape = a.shape.copy()

    for i, dim in enumerate(a.shape):
        output_shape[i] = padding[i][0] + padding[i][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        delta = rng.choice([1, 2])
        # Shrinking a dimension no larger than delta would give a zero or
        # negative size, so grow it instead; the shape mismatches either way
        if output_shape[bad_dim] > delta:
            output_shape[bad_dim] -= delta
        else:
            output_shape[bad_dim] += delta
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005092
5093 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the DIM output tensor: shape ``[]`` with dtype SHAPE.

    For WrongOutputType a random dtype from a fixed non-SHAPE list is used
    instead. ``a`` and ``axis`` do not affect the result.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes))
    return ser.addOutput([], rng.choice(wrong_dtypes))
5113
5114 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the RESHAPE output tensor with the requested ``shape``.

    TensorSizeInputOutputMismatch enlarges every dimension; WrongOutputType
    picks a dtype different from the input's.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(new_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005138
5139 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the SLICE output tensor, whose shape is ``size``.

    ERROR_IF variants perturb every dimension, reuse the input shape, change
    the rank, or pick a dtype different from the input's.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        # Dimensions of size <= 2 are only increased; larger ones may move
        # either way
        for idx, dim in enumerate(out_shape):
            steps = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(steps)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005172
5173 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the TILE output tensor: each dimension times its multiple.

    RankMismatch changes the output rank; WrongOutputType picks a dtype
    different from the input's.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for idx, dim in enumerate(a.shape):
        tiled_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = gtu.get_rank_mismatch_shape(rng, tiled_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(tiled_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(tiled_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005201
5202 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the TRANSPOSE output tensor: output dim i is ``a.shape[perms[i]]``.

    With invalid perms (IndexOutsideBounds / IndexUsedTwice) the input shape
    is kept. Other ERROR_IF variants enlarge all dims, change the rank, or
    pick a dtype different from the input's.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    perms_are_usable = error_name not in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if perms_are_usable:
        for out_axis in range(len(out_shape)):
            out_shape[out_axis] = a.shape[perms[out_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005234
5235 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the GATHER output tensor of shape [N, W, C].

    ``values`` is expected to be [N, K, C] and ``indices`` [N, W]; the rank
    checks are skipped for the WrongRank ERROR_IF case.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, width, channels]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005260
5261 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the SCATTER output tensor, shaped like ``values_in``.

    ``values_in`` is expected to be [N, K, C], ``indices`` [N, W] and
    ``input`` [N, W, C]; the rank checks are skipped for WrongRank.
    WrongOutputType picks a dtype different from ``values_in``'s.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy so the output tensor does not share (alias) the input's shape
    # list, matching every other output shaper
    output_shape = list(values_in.shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005289
5290 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the TABLE output tensor, shaped like ``input``.

    INT16 input yields an INT32 output; any other input dtype yields INT8.
    WrongOutputType picks a different dtype instead.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005309
5310 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the RESIZE output tensor (NHWC).

    ``scale`` is ``[y_n, y_d, x_n, x_d]``; each spatial output size is
    ``((in - 1) * n - offset + border) // d + 1``. ``mode`` and
    ``input_dtype`` are not used here.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the output size can still be computed
        y_n, y_d, x_n, x_d = (max(v, 1) for v in (y_n, y_d, x_n, x_d))

    oh = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    ow = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # Keep the output tensor valid, since scale, offset or border may
        # have been changed for the ERROR_IF case
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_limit), min(ow, dim_limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(size, step):
            # Move by one whole denominator step (avoids the non-integer
            # error case); step down when stepping up would reach the limit
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            oh = _shift(oh, y_d)
        if change in (2, 3):
            ow = _shift(ow, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses the batch size rather than
        # input.shape[3], presumably because input.shape[3] may not exist
        # when the input rank is wrong — confirm
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005388
5389 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create a type-conversion output tensor: ``val``'s shape, ``out_dtype``.

    ``rng`` and ``error_name`` are unused; the shape object is passed through
    as-is (not copied).
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005392
5393 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the TRANSPOSE_CONV2D output tensor with the given ``output_shape``.

    ConvOutputShapeMismatch grows H and/or W by 1-3; WrongInputType and
    WrongOutputType adjust the output dtype.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # NOTE(review): this updates the caller's output_shape list in place;
        # confirm whether callers rely on that before switching to a copy
        for axis_idx, selectors in ((1, (1, 3)), (2, (2, 3))):
            if change in selectors:
                output_shape[axis_idx] = output_shape[axis_idx] + rng.choice(choices)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are valid
        # accumulator types for FP16); otherwise only the expected dtype
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005418
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an FFT2D operator.

    Both outputs share the same shape and dtype (presumably the real and
    imaginary parts). The shape is a copy of ``ifm1.shape`` and the dtype is
    the common input dtype, unless an ERROR_IF condition is being injected.

    Args:
        serializer: serializer the outputs are added to via ``addOutput``.
        rng: random generator used only when injecting ERROR_IF conditions.
        ifm1, ifm2: the two input tensors; their dtypes must match.
            Their shapes must also match unless
            ``error_name == ErrorIf.FFTInputShapeMismatch``.
        error_name: optional ERROR_IF condition to inject.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # ``.copy()`` assumes the shape is a list-like with a copy method
    # (a tuple would not work here) -- TODO confirm against callers.
    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow either dim 1 or dim 2 by a random amount.
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    # Both outputs are registered with the very same shape object and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5449
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an RFFT2D operator.

    The output shape equals the input shape, except that the last dimension
    becomes ``W // 2 + 1``. This is presumably because a real-input FFT only
    needs the non-negative frequency bins. Both outputs share that shape and
    the input's dtype, unless an ERROR_IF condition is being injected.

    Args:
        serializer: serializer the outputs are added to via ``addOutput``.
        rng: random generator used only when injecting ERROR_IF conditions.
        value: the input tensor; rank 3 is required unless
            ``error_name == ErrorIf.WrongRank``.
        error_name: optional ERROR_IF condition to inject.

    Returns:
        A list of the two values returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow either dim 1 or dim 2 by a random amount.
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    # Both outputs are registered with the very same shape object and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]