blob: 53b0b75d13032d8f2148028fa78d462ec3d3ba46 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner placed at the top of every generated C++ meta data file.
# The copyright year is taken from the local date at import time.
TOSA_AUTOGENERATED_HEADER = "".join(
    f"{banner_line}\n"
    for banner_line in (
        f"// Copyright (c) {datetime.today().year}, ARM Limited",
        "// SPDX-License-Identifier: Apache-2.0",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests",
    )
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by the test generator; this currently
# matches the 8K level defined in the TOSA specification.
TOSA_TENSOR_MAX_RANK = 6
# Further limits named after the 8K level (scale, kernel and stride maxima)
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8 * 1024
TOSA_8K_LEVEL_MAX_STRIDE = 8 * 1024

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Set up generator state from the parsed command line arguments.

    Args:
        args: parsed arguments. This method reads output_dir, random_seed,
            generate_lib_path and tensor_fp_value_range; the whole object is
            kept as self.args for the other methods.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # Created per test by createSerializer()
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this shape instead of a random one
    self.targetted_shape = None
    # Validates every desc.json before it is written out
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def substitute_max(value, fp_max):
        # "max" / "-max" on the command line stand for +/- the dtype maximum
        if value == "max":
            return fp_max
        if value == "-max":
            return -fp_max
        return value

    # Per floating point dtype (low, high) range for random data, with the
    # two end points put into ascending order
    self.random_float_range = {}
    for fp_dtype in (DType.FP32, DType.FP16, DType.BF16):
        fp_max = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype]
        self.random_float_range[fp_dtype] = tuple(
            sorted(
                substitute_max(bound, fp_max)
                for bound in args.tensor_fp_value_range
            )
        )
78
def createSerializer(self, opName, testPath):
    """Create the output directory and a fresh serializer for one test.

    The test directory is <basePath>/<opName>/<testPath>. The
    <opName>/<testPath> part is stored in self.testPath for serialize().
    """
    self.testPath = os.path.join(opName, testPath)
    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        # Constant dumping was requested
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Default: embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(out_dir, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the current test's serializer (None until createSerializer runs)."""
    current = self.ser
    return current
95
def serialize(self, testName, metaData=None):
    """Write the output files of the current test into its directory.

    Always writes <testName>.tosa (the flatbuffer) and desc.json. When
    metaData is given, it is added to desc.json under "meta". Its
    "compliance" entry is also written as <testName>_meta_compliance.cpp.
    With lazy data generation, its "data_gen" entry is written as
    <testName>_meta_data_gen.cpp. Each .cpp file holds the JSON inside a
    C++ raw string literal.

    desc.json is schema-validated after the .tosa file is written and
    before the .cpp files and desc.json itself are written.
    """
    test_dir = Path(self.basePath) / self.testPath

    def write_cpp_json(file_name, comment_line, declaration, payload):
        # Emit `payload` as JSON inside a C++ raw string declaration
        with (test_dir / file_name).open("w") as cpp_file:
            cpp_file.write(TOSA_AUTOGENERATED_HEADER)
            cpp_file.write(comment_line)
            cpp_file.write(declaration)
            json.dump(payload, cpp_file)
            cpp_file.write(')";\n\n')

    # TOSA flatbuffer binary
    with (test_dir / f"{testName}.tosa").open("wb") as fb_file:
        fb_file.write(self.ser.serialize())

    # Test descriptor from the serializer, extended with any meta data
    descriptor = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        descriptor["meta"] = metaData

    self.descSchemaValidator.validate_config(descriptor)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation set up, only needed when data is made later
            write_cpp_json(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{test_dir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Compliance validation set up for the expected results
            write_cpp_json(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{test_dir.stem} = R"(',
                metaData["compliance"],
            )

    with (test_dir / "desc.json").open("w") as desc_file:
        json.dump(descriptor, desc_file, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Replace self.rng with a new generator seeded by seed.

    When seed is None, random_seed + 1 is used instead.
    """
    self.rng = np.random.default_rng(
        self.random_seed + 1 if seed is None else seed
    )
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) boundaries for random values of dtype.

    For integer and boolean types, high is exclusive by default. With
    high_inclusive=True, the returned high is the exclusive bound minus
    one. Floating point types return their configured range from
    self.random_float_range unchanged, whatever high_inclusive is.

    Raises:
        Exception: when dtype has no known range.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    if dtype == DType.BOOL:
        low, high = 0, 2
    elif dtype == DType.UINT8:
        low, high = 0, 1 << 8
    elif dtype == DType.UINT16:
        low, high = 0, 1 << 16
    elif dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        low, high = -7, 8
    elif dtype == DType.INT8:
        low, high = -(1 << 7), 1 << 7
    elif dtype == DType.INT16:
        low, high = -(1 << 15), 1 << 15
    elif dtype in (DType.INT32, DType.SHAPE):
        # SHAPE is restricted to the INT32 range to avoid too large values
        low, high = -(1 << 31), 1 << 31
    elif dtype == DType.INT48:
        low, high = -(1 << 47), 1 << 47
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return (low, high - 1) if high_inclusive else (low, high)
178
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of the given shape for dtype.

    data_range, if given, replaces the (low, high) bounds from
    getDTypeRange(dtype). BOOL ignores the bounds. The result is int64 for
    INT48, float16 for FP16 and float32 for FP32/BF16. BF16 values are
    passed through gtu.vect_f32_to_bf16. Every other type gives int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)
    samples = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(samples))
    return samples
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Under lazy data generation each placeholder gets no data (None).
    Otherwise each gets a new random tensor. Returns the results of
    self.ser.addPlaceholder, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
217
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Under lazy data generation each constant gets no data (None).
    Otherwise each gets a new random tensor. Returns the results of
    self.ser.addConst, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
230
def makeShape(self, rank):
    """Return a shape with `rank` dimensions as an int32 numpy array.

    If self.targetted_shape is set (truthy), that shape is returned instead.
    Otherwise each dimension is drawn from the half-open range
    [tensor_shape_range[0], tensor_shape_range[1]).
    """
    if not self.targetted_shape:
        dim_low = self.args.tensor_shape_range[0]
        dim_high = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=dim_low, high=dim_high, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
def setTargetShape(self, shape):
    """Make makeShape() return `shape`; a falsy value restores random shapes."""
    self.targetted_shape = shape
244
def randInt(self, low=0, high=256):
    """Return a single random np.int32 value from [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
247
def getRandNumberDType(self, dtype):
    """Return one random scalar for dtype, bounded by getDTypeRange(dtype).

    FP32 gives np.float32 and FP16 gives np.float16. BF16 returns the
    result of gtu.vect_f32_to_bf16 applied to a float32 draw. BOOL is a
    random choice of False/True, INT48 gives np.int64 and any other type
    gives np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        draw = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(draw)
        draw = np.float32(draw)
        return gtu.vect_f32_to_bf16(draw) if dtype == DType.BF16 else draw

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 needs a wider container than the other integer types
    int_cast = np.int64 if dtype == DType.INT48 else np.int32
    return int_cast(self.rng.integers(low, high, size=1))[0]
265
def shapeStr(self, shape):
    """Return the dimensions of shape joined by "x", e.g. [2, 3] -> "2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
def typeStr(self, dtype):
    """Return the short string used to name dtype in test names.

    A list or tuple of two or more dtypes gives the names of the first two
    joined by "x". The remaining entries (e.g. an accumulator type) are
    still converted, so they must be known, but they are left out of the
    result.

    Raises:
        Exception: when a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
def typeWidth(self, dtype):
    """Return the width recorded for dtype in gtu.DTYPE_ATTRIBUTES.

    Raises:
        Exception: when dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
def constrictBatchSize(self, shape):
    """Cap shape[0] at args.max_batch_size, modifying shape in place.

    Nothing changes when no maximum is set or explicit target shapes were
    requested. The same shape object is returned.
    """
    if not self.args.max_batch_size or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
301
def makeDimension(self):
    """Return one random dimension size drawn via randInt from args.tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
306
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build compliance meta data for the expected output tensor of a test.

    Returns None when there is no compliance checking:
    - for ERROR_IF tests (errorName set);
    - when the output dtype is unsupported by compliance;
    - for MATMUL/CONV2D/FULLY_CONNECTED with an unsupported input type.

    Otherwise returns a dict with "mode" (a gtu.ComplianceMode name) and
    "data_type" (the output dtype's JSON name). DOT_PRODUCT mode adds a
    "dot_product_info" entry and ULP mode adds a "ulp_info" entry.
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32
    # outputs are not supported yet
    dot_product_ops = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in dot_product_ops
    ):
        return None

    meta = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    data_gen_type = argsDict["dg_type"]
    if data_gen_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        meta["dot_product_info"] = {
            "s": argsDict["s"],
            # "ksb" takes priority over "ks" when both are present
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif data_gen_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        meta["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
    else:
        mode = gtu.ComplianceMode.EXACT
    meta["mode"] = gtu.ComplianceMode(mode).name

    return meta
351
352 # Build Op functions
353 # Create the output tensor (calling OutputShaper as needed)
354 # Do final tweaks to attributes (if necessary for errorIf)
355 # Add Op into graph
356 # Return resulting tensor information or BuildInfo
357
class BuildInfo:
    """Result of a build_* method: output tensor plus its compliance meta data."""

    def __init__(self, resultTensor, complianceDict):
        # resultTensor: the operator's output tensor.
        # complianceDict: compliance meta data, which may be None when no
        # compliance checking applies.
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700364
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a one-input operator to the test graph.

    Args:
        op: operator definition dict (uses "op" and "operands").
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None.
        qinfo: quantization info; NEGATE reads qinfo[0] and qinfo[1].

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance meta
        data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and out_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, in_tensor.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [in_tensor.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700417
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two-input operator whose output shape comes from binaryBroadcastOp.

    inputs must hold exactly two tensors. qinfo is accepted but unused.
    Returns a TosaTestGen.BuildInfo (output tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700459
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator shaped by binaryNonBroadcastOp; return its output.

    validator_fcns and error_name are accepted but unused: no ERROR_IF
    validation is done here.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
464
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add a broadcasting two-input operator with an ArithmeticRightShiftAttribute.

    `round` is passed straight to the attribute. Returns the output tensor,
    or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [a.name, b.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
502
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a multiply operator with a MulAttribute built from args_dict["shift"].

    For non floating point inputs the output type is forced to INT32. For a
    WrongOutputType ERROR_IF test it is then replaced by a random choice of
    INT8/INT16/INT48. qinfo is accepted but unused.

    Returns a TosaTestGen.BuildInfo (output tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    shift = args_dict["shift"]

    out_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        out_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        out_tensor.setDtype(self.rng.choice(wrong_dtypes))

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700558
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a one-input operator carrying a TableAttribute built from `table`.

    Returns the output tensor, or None when ERROR_IF validation rejects
    the test.
    """
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [a.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return out_tensor
592
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a three-input select-style operator (cond, a, b).

    Returns the output tensor, or None when ERROR_IF validation rejects
    the test.
    """
    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [cond.name, a.name, b.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
629
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input comparison operator shaped by binaryComparisonOp.

    inputs must hold exactly two tensors. qinfo is accepted but unused.
    Returns a TosaTestGen.BuildInfo (output tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Input/output lists may be invalidated for ERROR_IF tests
    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700677
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX-style operator reducing along ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` now default to ``None``, matching the
    other builders. Existing positional and keyword callers are unaffected.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator from ``args_dict`` kernel/stride/pad.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly one input tensor.
        args_dict: Must contain ``"stride"``, ``"pad"`` and ``"kernel"``.
            ``"acc_type"`` is optional; max_pool has none, so it falls back to
            ``DType.UNKNOWN``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``PoolAttribute``, or ``None``
            (treated as ``[0, 0]``).

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 1
    ifm = inputs[0]

    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    is_wrong_input_type = error_name == ErrorIf.WrongInputType
    if is_wrong_input_type and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Recompute zero points so they agree with the (wrong) tensor types.
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    placeholder_count, const_count = op["operands"]

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [result_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100795
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator from ``(ifm, weights, bias)`` inputs.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` holds operand counts that are summed.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Must contain ``"acc_type"``, ``"stride"``, ``"pad"``
            (4 values) and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``.
            ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700877
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator from ``(ifm, weights, bias)`` inputs.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` holds operand counts that are summed.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Must contain ``"acc_type"``, ``"stride"``, ``"pad"``
            (6 values) and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``.
            ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        The result tensor (not a ``BuildInfo``), or ``None`` when ERROR_IF
        validation rejects the configuration.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
954
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    The ``filter`` parameter name shadows a builtin. It is kept because
    callers may pass it by keyword.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` holds operand counts that are summed.
        ifm: Input feature map tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype forwarded to the output shaper.
        stride: Stride values.
        out_pad: Output padding (4 values).
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``TransposeConvAttribute``, or
            ``None``. ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        configuration.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1022
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator from ``(ifm, weights, bias)`` inputs.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` holds operand counts that are summed.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Must contain ``"acc_type"``, ``"stride"``, ``"pad"`` and
            ``"dilation"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``ConvAttribute``, or ``None``.
            ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        configuration.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1098
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator from ``(ifm, weights, bias)`` inputs.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly three tensors: input, weights, bias.
        args_dict: Must contain ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``FullyConnectedAttribute``, or
            ``None``. ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001154
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator from two input tensors.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly two input tensors ``(a, b)``.
        args_dict: Must contain ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Two zero points passed to ``MatMulAttribute``, or ``None``.
            ``None`` is now treated as ``[0, 0]`` instead of raising
            ``TypeError`` when indexed.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; default the zero points rather than
    # failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001204
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator along ``args_dict["axis"]``.

    ``validator_fcns`` now defaults to ``None``, matching the other builders.
    Existing callers are unaffected.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; must contain ``"axis"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration. The
        compliance entry is ``None`` for ``Op.REDUCE_PRODUCT``.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if op["op"] == Op.REDUCE_PRODUCT:
        # TODO: Add compliance support!
        compliance = None
    else:
        compliance = self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001253
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly drawn bounds.

    Two values of the input dtype are drawn with ``getRandNumberDType``.
    Normally the smaller becomes ``min_val`` and the larger ``max_val``.
    For ``ErrorIf.MaxSmallerMin`` the pair is redrawn until the two values
    differ, and then the roles are swapped so that max < min.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 1
    a = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Two draws of the input dtype, in this order.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(pair), max(pair)
    else:
        # Equal values would not trigger the max < min error, so redraw.
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)

    placeholder_count, const_count = op["operands"]

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes: bounds fill the first two ClampAttribute slots.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes: bounds fill the last two ClampAttribute slots.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001319
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator with a random FP32 attribute value.

    ``validator_fcns`` is accepted for interface compatibility but is not
    consulted. No ERROR_IF validation is performed here.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    fp32_value = self.getRandNumberDType(DType.FP32)
    leaky_attr = ts.TosaSerializerAttribute()
    leaky_attr.LeakyReluAttribute(fp32_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1328
# NOTE: this op needs an additional type/input; only one input is wired here.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator with a single input and no attributes.

    ``validator_fcns`` is accepted for interface compatibility but is not
    consulted. No ERROR_IF validation is performed here.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1335
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an attribute-free single-input activation operator.

    Args:
        op: Operator description dict. ``op["op"]`` is the serialized op, and
            ``op["operands"]`` is a pair of operand counts that are summed.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor plus compliance metadata), or
        ``None`` when ERROR_IF validation rejects the configuration.
    """
    assert len(inputs) == 1
    a = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    error_if_kwargs = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001376
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT operator joining all tensors in `inputs` along an axis.

    Args:
        op: operator description dict; its "op" and "operands" entries are read.
        inputs: the tensors to concatenate.
        args_dict: must contain "axis", stored in an AxisAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ERROR_IF case being generated, or None.
        qinfo: accepted for interface compatibility; not used here.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None) (no compliance metadata),
        or None when evValidateErrorIfs rejects the generated test.
    """
    axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Sanity check: axis must be a plain int. isinstance is the idiomatic
        # check; bool is excluded explicitly because it subclasses int and
        # was rejected by the previous type(axis) == int comparison.
        assert isinstance(axis, int) and not isinstance(axis, bool)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001424
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator on a single input tensor.

    args_dict supplies "pad" (an array-like with .flatten(), stored flattened
    in the PadAttribute), "pad_const_int" and "pad_const_fp".

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        evValidateErrorIfs rejects the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_amounts = args_dict["pad"]
    fill_int = args_dict["pad_const_int"]
    fill_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, pad_amounts, error_name)

    # Kept ahead of validation: PadAttribute is handed the serializer's
    # builder, so its position relative to other serializer calls is preserved.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_amounts.flatten(), fill_int, fill_fp)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=pad_amounts,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    meta = self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    return TosaTestGen.BuildInfo(out, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001482
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator for one input tensor and axis args_dict["axis"].

    The axis is passed to OutputShaper.dimOp and stored in an AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None) (no compliance metadata),
        or None when evValidateErrorIfs rejects the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, dim_axis, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001528
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator on a single input tensor.

    args_dict["new_shape"] is passed to OutputShaper.reshapeOp and stored in
    the ReshapeAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        evValidateErrorIfs rejects the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(self.ser, self.rng, src, target_shape, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(target_shape)

    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)

    meta = self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    return TosaTestGen.BuildInfo(out, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001574
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator on one input tensor along args_dict["axis"].

    The output tensor is created by OutputShaper.unaryOp; the axis is stored in
    an AxisAttribute.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None) (no compliance metadata),
        or None when evValidateErrorIfs rejects the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001614
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on tensor `a` using permutation `perms`.

    `perms` is passed to OutputShaper.transposeOp and stored in the
    TransposeAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out
1650
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator on tensor `a` with the given start and size.

    `start` and `size` are passed to OutputShaper.sliceOp and stored in the
    SliceAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    out = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        start=start,
        size=size,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out
1689
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator replicating tensor `a` by `multiples`.

    `multiples` is passed to OutputShaper.tileOp and stored in the
    TileAttribute.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    out = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        input1=a,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out
1724
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator reading from `values` with generated indices.

    An INT32 constant indices tensor of shape (values.shape[0], W) is created
    here. W is drawn from self.args.tensor_shape_range. Every index lies in
    [0, values.shape[1]), so the indices never exceed dimension K of `values`.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    K = values.shape[1]  # K
    W = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    # rng.integers excludes `high`, so every index is in [0, K).
    indices_arr = np.int32(
        self.rng.integers(low=0, high=K, size=[values.shape[0], W])
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values.name, indices.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=result_tens.shape,
        input_dtype=values.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1771
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator writing `input` into `values_in`.

    An INT32 constant indices tensor of shape (values_in.shape[0], W) is
    generated, where W = input.shape[1]. Each row is the first W entries of a
    random permutation of range(K), with K = values_in.shape[1]. No output
    index therefore repeats within a batch, as the spec requires: "It is not
    permitted to repeat the same output index within a single SCATTER
    operation and so each output index occurs at most once."

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    K = values_in.shape[1]  # K
    W = input.shape[1]  # W
    # Unique indices per batch are only possible when K >= W.
    assert K >= W

    batch_count = values_in.shape[0]
    indices_arr = np.array(
        [self.rng.permutation(K)[:W] for _ in range(batch_count)], dtype=np.int32
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_in.name, indices.name, input.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out
1822
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on tensor `input`.

    mode, scale, offset and border are passed to OutputShaper.resizeOp and
    stored in the ResizeAttribute. The explicit input_dtype/output_dtype
    arguments, not input.dtype, are what ERROR_IF validation sees.

    Returns:
        The result tensor, or None when evValidateErrorIfs rejects the test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=operand_total,
    )
    if not passed:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1884
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    Each output tensor is created by OutputShaper.unaryOp from its input.
    No ERROR_IF validation is performed. Only the first result tensor is
    returned; the second appears only in the operator's output list.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the operator enum held in op["op"], as every other builder does,
    # instead of the whole op description dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1892
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the already-created tensor `val` as a graph output.

    No operator is added. op, validator_fcns and error_name are accepted only
    to match the common builder signature.

    Returns:
        `val` itself.
    """
    const_tensor = val
    self.ser.addOutputTensor(const_tensor)
    return const_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001896
1897 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator converting one input tensor to args_dict["out_type"].

    The output tensor is created by OutputShaper.typeConversionOp with the
    requested output dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance), or None when
        evValidateErrorIfs rejects the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts
    # For ERROR_IF tests the tensor name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    return TosaTestGen.BuildInfo(out, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -07001941
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001942 def build_rescale(
1943 self,
1944 op,
1945 val,
1946 out_dtype,
1947 scale32,
1948 double_round,
1949 per_channel,
1950 validator_fcns,
1951 error_name,
1952 ):
1953 result_tens = OutputShaper.typeConversionOp(
1954 self.ser, self.rng, val, out_dtype, error_name
1955 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001956
1957 if per_channel:
1958 nc = val.shape[-1]
1959 else:
1960 nc = 1
1961
1962 in_type_width = self.typeWidth(val.dtype)
1963 out_type_width = self.typeWidth(out_dtype)
1964
Kevin Cheng3a478572021-01-22 17:21:02 -08001965 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001966 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001967 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001968 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001969 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001970 in_type_width += 1
1971 elif error_name in [
1972 ErrorIf.InputZeroPointNotZero,
1973 ErrorIf.U16InputZeroPointNotValid,
1974 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001975 input_zp = self.randInt(-128, 128)
1976 if input_zp == 0:
1977 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001978 in_type_width += 1
1979 elif val.dtype == DType.UINT16:
1980 # Must come after ErrorIf.U16InputZeroPointNotValid check
1981 input_zp = self.rng.choice([0, 32768])
1982 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001983 else:
1984 input_zp = 0
1985
Kevin Cheng3a478572021-01-22 17:21:02 -08001986 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001987 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001988 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001989 elif out_dtype == DType.UINT8:
1990 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001991 out_type_width += 1
1992 elif error_name in [
1993 ErrorIf.OutputZeroPointNotZero,
1994 ErrorIf.U16OutputZeroPointNotValid,
1995 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001996 output_zp = self.randInt(-128, 128)
1997 if output_zp == 0:
1998 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001999 out_type_width += 1
2000 elif out_dtype == DType.UINT16:
2001 # Must come after ErrorIf.U16OutputZeroPointNotValid check
2002 output_zp = self.rng.choice([0, 32768])
2003 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002004 else:
2005 output_zp = 0
2006
2007 # Calculate scale based on:
2008 # scale = a *(2^output_width)/(2^input_width))
2009
2010 a = np.float32(self.rng.random(size=[nc]))
2011 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
2012
2013 if scale32:
2014 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01002015 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002016 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2017 else:
2018 # Cap the scaling at 2^15 - 1 for scale16
2019 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2020
Kevin Cheng550ccc52021-03-03 11:21:43 -08002021 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002022
2023 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2024 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002025 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2026 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002027
2028 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002029 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2030 scale_arr[i], scale32
2031 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002032 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2033 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002034
Kevin Cheng550ccc52021-03-03 11:21:43 -08002035 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002036 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002037 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002038 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002039 assert val.placeholderFilename
2040 values = np.load(
2041 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2042 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002043 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2044 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2045 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2046 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002047 if not np.all(np.array_equal(values, val_adj)):
2048 # Values changed so overwrite file with new values
2049 np.save(
2050 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2051 val_adj,
2052 False,
2053 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002054
Matthew Haddonc2025212021-10-08 21:21:05 +01002055 # Invalidate Input/Output list for error if checks.
2056 input_list = [val.name]
2057 output_list = [result_tens.name]
2058 pCount, cCount = op["operands"]
2059 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002060 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2061 self, error_name, input_list, output_list
2062 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002063
2064 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002065 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002066 self.ser,
2067 validator_fcns,
2068 error_name,
2069 op=op,
2070 input_dtype=val.dtype,
2071 output_dtype=out_dtype,
2072 input_shape=val.shape,
2073 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002074 scale32=scale32,
2075 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002076 input_list=input_list,
2077 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002078 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002079 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002080 ):
2081 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002082
Eric Kunzee5e26762020-10-13 16:11:07 -07002083 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002084 attr.RescaleAttribute(
2085 input_zp,
2086 output_zp,
2087 multiplier_arr,
2088 shift_arr,
2089 scale32,
2090 double_round,
2091 per_channel,
2092 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002093
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002094 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002095 return result_tens
2096
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF operator.

    Normally this is a rank-0 BOOL constant holding ``cond``. Two ERROR_IF
    cases deliberately break that contract:

    * ``ErrorIf.CondIfCondNotMatchingBool``: the dtype is replaced by the
      one returned from ``gtu.get_wrong_output_type(op, self.rng, DType.BOOL)``.
    * ``ErrorIf.CondIfCondShapeNotSizeOne``: the shape is randomly either
      ``[2]`` or ``[1, 2]``.
    """
    wants_wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    dtype = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if wants_wrong_type
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # A valid condition must be of size 1 (rank 0)
        shape = []
    else:
        shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(shape, dtype, [cond])
2113
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF operator whose THEN/ELSE blocks each output a constant.

    Only the shape of ``then_tens`` is used. The incoming value of
    ``else_tens`` is never read. The body of each block is a freshly
    generated random INT32 constant (values drawn from [0, 256)) of that
    shape.

    Args:
        op: operator description dict; ``op["op"]`` is the operator that is
            serialized and the dict is forwarded to the ERROR_IF validators.
        then_tens: tensor whose ``shape`` is used as the output shape.
        else_tens: unused.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions passed to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case to generate, or None for a valid test.

    Returns:
        The INT32 result tensor, or None if
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
        In that case the operator and both blocks have already been added
        to the serializer.
    """
    # For cond_if with constants, we're supplied with then/else tensors that we ignore
    # (except for the generated shape) and the condition. Build Then/Else blocks
    # and fill them with const nodes for the body.

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Make then/else tensors
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        # NOTE(review): for a rank-0 out_shape this loop does nothing, so
        # incorrect_shape ends up equal to out_shape.
        for i in range(len(incorrect_shape)):
            # Dims > 3 may shrink or grow by 2-3; smaller dims only grow,
            # so no dimension becomes non-positive.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    # The RNG draw order above and here determines the generated test data
    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    self.ser.addBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_tens)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2182
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF operator whose THEN/ELSE blocks combine ``a`` and ``b``.

    Each block takes ``a`` and ``b`` as inputs and applies one binary
    operator. For FP32, BF16, FP16 and INT32 inputs the THEN block uses ADD
    and the ELSE block uses SUB. For INT8 and INT16 inputs they use
    LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. For the CondIfInputList* and
    CondIfOutputList* graph-mismatch ERROR_IF cases, one block is given an
    input or output whose shape differs from ``a``.

    Args:
        op: operator description dict; ``op["op"]`` is the serialized
            COND_IF operator and the dict itself is forwarded to the
            ERROR_IF validators.
        a: first operand; its shape and dtype define the result tensor.
        b: second operand.
        cond: value stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case to generate, or None for a valid test.

    Returns:
        The result tensor, or None if
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        # NOTE(review): unlike build_cond_if_const, small dims are not guarded
        # here and may become non-positive - confirm this is intended.
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: reusing ``op`` here would overwrite the
    # operator dict that must be passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2265
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP operator that repeatedly adds ``a`` into an accumulator.

    The loop carries three tensors: a rank-0 INT32 counter initialised to
    ``iter_val``, ``a`` itself, and an accumulator initialised to zeros
    shaped like ``a``. The COND block computes ``counter > 0``. The BODY
    block computes ``a + acc`` and ``counter - 1``.

    ERROR_IF cases change the following:

    * the shape of the operator's accumulator output
      (``InputListOutputListMismatch``);
    * the COND/BODY block input or output shapes
      (``InputListCondGraphMismatch``, ``InputListBodyGraphInputMismatch``,
      ``InputListBodyGraphOutputMismatch``);
    * the dtype or shape of the COND output
      (``CondGraphOutputNotMatchingBool``, ``CondGraphOutputShapeNotSizeOne``).

    Returns:
        The accumulator output tensor, or None if
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Loop counter (named iter_tens to avoid shadowing the ``iter`` builtin)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give it an extra dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2386
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Build an FFT2D operator over the two input tensors ``val1`` and ``val2``.

    Output tensors are created by ``OutputShaper.fft2dOp``. The input and
    output name lists may be altered for ERROR_IF tests before validation.
    The FFT attribute's ``local_bound`` flag is currently always False.

    Returns:
        The tensors produced by ``OutputShaper.fft2dOp``, or None if
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    output_names = []
    output_shapes = []
    output_dtypes = []
    for tensor in results:
        output_names.append(tensor.name)
        output_shapes.append(tensor.shape)
        output_dtypes.append(tensor.dtype)

    # ERROR_IF tests may tamper with the operand lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2437
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Build an RFFT2D operator over the input tensor ``val``.

    Output tensors are created by ``OutputShaper.rfft2dOp``. The input and
    output name lists may be altered for ERROR_IF tests before validation.
    The RFFT attribute's ``local_bound`` flag is currently always False.

    Returns:
        The tensors produced by ``OutputShaper.rfft2dOp``, or None if
        ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    output_names = []
    output_shapes = []
    output_dtypes = []
    for tensor in results:
        output_names.append(tensor.name)
        output_shapes.append(tensor.shape)
        output_dtypes.append(tensor.dtype)

    # ERROR_IF tests may tamper with the operand lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)
    return results
2483
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator description dict; reads ``op["rank"]`` (an inclusive
            ``(min, max)`` pair) and ``op["types"]``.
        shapeFilter: requested shapes; a falsy value means ``[None]``.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An op dtype that is a list
            matches when its first element is requested.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator, required for negative tests. Its
            ``param_reqs`` override the rank/dtype/shape filters.

    Returns:
        A dict with ``shapeFilter``, ``rankFilter`` and ``dtypeFilter`` keys,
        or None for a negative test without a validator or for any other
        ``testType``.
    """
    # Keep default test sizes reasonably small: ranks 1-4 inclusive
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    low_rank, high_rank = op["rank"]
    if rankFilter is not None:
        # Only keep the requested ranks that the operator allows
        ranks = [rank for rank in rankFilter if low_rank <= rank <= high_rank]
    elif shapes[0] is None:
        # Bounded by the default range or the operator, whichever is smaller
        op_ranks = range(low_rank, high_rank + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(low_rank, high_rank + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    # Without an explicit requirement, only use the first two shapes to keep
    # the number of negative tests small
    neg_shapes = shapes[:2] if reqs["shape"] is None else reqs["shape"]
    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2564
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for operator ``opName``.

    The random generator ``self.rng`` is re-seeded from ``self.random_seed``
    on every call, so the same arguments produce the same test list.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to restrict generation to; None (the default)
            means no shape restriction, i.e. ``[None]``.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. For positive tests, entries flagged by any of
        ``op["invalid_test_validators"]`` are removed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # A None default avoids a shared mutable default list; [None] (no shape
    # restriction) is what the previous default meant.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A list (not a generator) is passed to any() so
        # every validator is still called for every test.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                [
                    validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    )
                    for validator_fcn in invalid_test_validators
                ]
            )
        ]

    return testList
2671
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Generate the tensors for one test of ``opName``, build it and serialize it.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: name of the test being created (also handed to the serializer).
        dtype_or_dtypeList: either one dtype, replicated for every operand, or an
            explicit per-operand list of dtypes.
        error_name: ERROR_IF condition under test; forwarded to the generators
            and build function (presumably None for positive tests - confirm).
        shapeList: list of operand shapes.
        testArgs: a dict (new "argsDict" interface, must contain "dg_type") or a
            sequence of extra positional build arguments (old interface).

    When the build function returns a falsy result nothing is serialized and an
    "Invalid ERROR_IF test created" message is printed instead.

    Raises:
        Exception: if ``opName`` has no entry in ``self.TOSA_OP_LIST``.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception(f"Cannot find op with name {opName}")
    op = self.TOSA_OP_LIST[opName]

    if self.args.verbose:
        print(f"Creating {testStr}")

    # A fresh serializer is needed for every test
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT has a variable number of inputs, so size the list from the shapes
        replicas = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * replicas

    if not is_concat:
        assert len(shapeList) == num_operands, (
            f"shapeList length {len(shapeList)} "
            f"must match number of operands {num_operands}"
        )
        assert len(dtypeList) == num_operands, (
            f"dtypeList length {len(dtypeList)} "
            f"must match number of operands {num_operands}"
        )

    # Quantization info is only produced for ops that declare a "qgen"
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra metadata written alongside the test into desc.json
    tensMeta = {}

    if isinstance(testArgs, dict):
        # New interface: the arguments arrive as an argsDict dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList

        result = build_fcn(
            self,
            op,
            tens,
            argsDict,
            validator_fcns=error_if_validators,
            error_name=error_name,
            qinfo=qinfo,
        )
    else:
        # Old interface: the arguments arrive as a positional list
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

        if error_if_validators is None:
            # Without validators, qinfo (when present) is passed positionally
            trailing_args = () if qinfo is None else (qinfo,)
            keyword_args = {}
        else:
            trailing_args = ()
            keyword_args = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                keyword_args["qinfo"] = qinfo

        try:
            result = build_fcn(
                self, op, *tens, *testArgs, *trailing_args, **keyword_args
            )
        except TypeError as e:
            # Usually a signature mismatch - dump what was passed before re-raising
            print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
            raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, so record any compliance metadata and serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # NOTE: only a single result output is currently expected
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002797
def createDynamicOpLists(self):
    """Expand the convolution ``*_TEMPLATE`` entries into concrete per-kernel ops.

    For each kernel size, the conv2d, depthwise_conv2d and transpose_conv2d
    templates (2-D kernels) and the conv3d template (3-D kernels) are shallow
    copied into a new ``self.TOSA_OP_LIST`` entry named e.g. ``conv2d_3x1``.
    Each copy gets ``"filter"`` set to the kernel and ``"template"`` set to
    False. The kernel sizes come from ``self.args.level8k``: small kernels
    normally, ``TOSA_8K_LEVEL_MAX_KERNEL``-sized ones at level 8k.

    Afterwards every entry whose ``"template"`` value is truthy is removed.
    Because ``conv2d_TEMPLATE`` is gone after the first run, later calls
    return immediately.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already expanded (can occur when the class is initialized more than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def instantiate(prefix, kernel):
        # Clone "<prefix>_TEMPLATE" as "<prefix>_AxB[xC]" with the kernel filled in
        concrete = op_list[prefix + "_TEMPLATE"].copy()
        concrete["filter"] = kernel
        concrete["template"] = False
        op_list[prefix + "_" + "x".join(str(dim) for dim in kernel)] = concrete

    # Insertion order is kept: per 2-D kernel conv2d, depthwise, transpose; then 3-D
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)
    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Collect the names first: the dict must not change size while being iterated
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2853
def initOpListDefaults(self):
    """Lint every ``self.TOSA_OP_LIST`` entry and fill in optional defaults.

    Checked for each entry, in this order:
      * "operands" must unpack into exactly two values
      * "build_fcn" must unpack into exactly four values
      * "types" must be present
      * "op" must be present
    An entry with no "rank" gets ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: describing the first missing/malformed field found.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tensor_gen, _values_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(
                f"Op {op_name} is missing a valid type list in TOSA_OP_LIST"
            )

        if "op" not in entry:
            raise Exception(f"Op {op_name} is missing the Op field in TOSA_OP_LIST")

        # Ops without an explicit rank restriction use the default range
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002895
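# Tensor operator list (TOSA_OP_LIST), keyed by test op name. Fields:
# 'op': the tosa Op enum value for the operator
# 'operands': tuple of (placeholder count, const count)
# 'rank': optional, restricts rank to the inclusive tuple (min, max);
#   if not specified, defaults to DEFAULT_RANK_RANGE, i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TosaTensorGen function,
#   TosaTensorValuesGen function, TosaArgGen function)
# 'types': list of datatypes to be tested
# Optional fields include 'qgen', 'invalid_test_validators', 'error_if_validators',
# 'data_gen', 'compliance' and 'template' (template entries are expanded per
# kernel size by createDynamicOpLists).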
# Floating-point dtypes under test.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer dtypes; INT4 is deliberately left out.
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# The integer dtypes followed by FP16, BF16 and FP32 (still no INT4).
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# The floating-point dtypes plus INT32.
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean dtypes.
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point dtypes (no INT32).
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# Dtype combinations for the conv and fully_connected entries; each row is
# [first input dtype, second input dtype, accumulator dtype].
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to ops whose entry has no explicit "rank".
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002947
2948 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002949 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002950 "argmax": {
2951 "op": Op.ARGMAX,
2952 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002953 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954 "build_fcn": (
2955 build_argmax,
2956 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002957 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002958 TosaArgGen.agAxis,
2959 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002960 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002961 "error_if_validators": (
2962 TosaErrorValidator.evAxisSmallerZero,
2963 TosaErrorValidator.evAxisLargerRank,
2964 TosaErrorValidator.evArgmaxOutputRankMismatch,
2965 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2966 TosaErrorValidator.evWrongRank,
2967 TosaErrorValidator.evWrongInputType,
2968 TosaErrorValidator.evWrongOutputType,
2969 TosaErrorValidator.evWrongInputList,
2970 TosaErrorValidator.evWrongOutputList,
2971 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002972 "data_gen": {
2973 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2974 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002976 "avg_pool2d": {
2977 "op": Op.AVG_POOL2D,
2978 "operands": (1, 0),
2979 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002980 "build_fcn": (
2981 build_pool2d,
2982 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002983 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002984 TosaArgGen.agPooling,
2985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002986 "qgen": TosaQuantGen.qgUnary,
2987 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002988 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002989 "error_if_validators": (
2990 TosaErrorValidator.evKernelSmallerOne,
2991 TosaErrorValidator.evStrideSmallerOne,
2992 TosaErrorValidator.evPadSmallerZero,
2993 TosaErrorValidator.evWrongRank,
2994 TosaErrorValidator.evWrongInputType,
2995 TosaErrorValidator.evWrongOutputType,
2996 TosaErrorValidator.evWrongInputList,
2997 TosaErrorValidator.evWrongOutputList,
2998 TosaErrorValidator.evInputZeroPointNotZero,
2999 TosaErrorValidator.evOutputZeroPointNotZero,
3000 TosaErrorValidator.evPadLargerEqualKernel,
3001 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003002 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003003 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003004 "data_gen": {
3005 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3006 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003007 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003008 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003009 "conv2d_TEMPLATE": {
3010 "op": Op.CONV2D,
3011 "operands": (1, 2),
3012 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003013 "build_fcn": (
3014 build_conv2d,
3015 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003016 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003017 TosaArgGen.agConv,
3018 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003019 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003020 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003021 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3022 "error_if_validators": (
3023 TosaErrorValidator.evWrongInputType,
3024 TosaErrorValidator.evWrongOutputType,
3025 TosaErrorValidator.evWrongInputList,
3026 TosaErrorValidator.evWrongOutputList,
3027 TosaErrorValidator.evInputZeroPointNotZero,
3028 TosaErrorValidator.evWeightZeroPointNotZero,
3029 TosaErrorValidator.evPadSmallerZero,
3030 TosaErrorValidator.evStrideSmallerOne,
3031 TosaErrorValidator.evDilationSmallerOne,
3032 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003033 TosaErrorValidator.evConvOutputShapeMismatch,
3034 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003035 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003036 "data_gen": {
3037 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3038 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003039 "template": True,
3040 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003041 # Templated operator. Filled in by createDynamicOpLists
3042 "conv3d_TEMPLATE": {
3043 "op": Op.CONV3D,
3044 "operands": (1, 2),
3045 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003046 "build_fcn": (
3047 build_conv3d,
3048 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003049 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003050 TosaArgGen.agConv,
3051 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003052 "qgen": TosaQuantGen.qgConv,
3053 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003054 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3055 "error_if_validators": (
3056 TosaErrorValidator.evWrongInputType,
3057 TosaErrorValidator.evWrongOutputType,
3058 TosaErrorValidator.evWrongInputList,
3059 TosaErrorValidator.evWrongOutputList,
3060 TosaErrorValidator.evInputZeroPointNotZero,
3061 TosaErrorValidator.evWeightZeroPointNotZero,
3062 TosaErrorValidator.evPadSmallerZero,
3063 TosaErrorValidator.evStrideSmallerOne,
3064 TosaErrorValidator.evDilationSmallerOne,
3065 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003066 TosaErrorValidator.evConvOutputShapeMismatch,
3067 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003068 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003069 "template": True,
3070 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003071 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003072 "depthwise_conv2d_TEMPLATE": {
3073 "op": Op.DEPTHWISE_CONV2D,
3074 "operands": (1, 2),
3075 "filter": [1, 1],
3076 "rank": (4, 4),
3077 "build_fcn": (
3078 build_depthwise_conv2d,
3079 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003080 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003081 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003082 ),
3083 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003084 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003085 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3086 "error_if_validators": (
3087 TosaErrorValidator.evWrongInputType,
3088 TosaErrorValidator.evWrongOutputType,
3089 TosaErrorValidator.evWrongInputList,
3090 TosaErrorValidator.evWrongOutputList,
3091 TosaErrorValidator.evInputZeroPointNotZero,
3092 TosaErrorValidator.evWeightZeroPointNotZero,
3093 TosaErrorValidator.evPadSmallerZero,
3094 TosaErrorValidator.evStrideSmallerOne,
3095 TosaErrorValidator.evDilationSmallerOne,
3096 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003097 TosaErrorValidator.evConvOutputShapeMismatch,
3098 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003099 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003100 "template": True,
3101 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003102 "fully_connected": {
3103 "op": Op.FULLY_CONNECTED,
3104 "operands": (1, 2),
3105 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003106 "build_fcn": (
3107 build_fully_connected,
3108 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003109 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003110 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003112 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003113 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003114 "error_if_validators": (
3115 TosaErrorValidator.evInputZeroPointNotZero,
3116 TosaErrorValidator.evWeightZeroPointNotZero,
3117 TosaErrorValidator.evWrongRank,
3118 TosaErrorValidator.evWrongInputType,
3119 TosaErrorValidator.evWrongOutputType,
3120 TosaErrorValidator.evWrongInputList,
3121 TosaErrorValidator.evWrongOutputList,
3122 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003123 "data_gen": {
3124 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003127 "matmul": {
3128 "op": Op.MATMUL,
3129 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003130 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003131 "build_fcn": (
3132 build_matmul,
3133 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003134 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003135 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 "qgen": TosaQuantGen.qgMatmul,
3138 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003139 "error_if_validators": (
3140 TosaErrorValidator.evInputZeroPointNotZero,
3141 TosaErrorValidator.evWrongRank,
3142 TosaErrorValidator.evWrongInputType,
3143 TosaErrorValidator.evWrongOutputType,
3144 TosaErrorValidator.evWrongInputList,
3145 TosaErrorValidator.evWrongOutputList,
3146 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003147 "data_gen": {
3148 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003149 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "max_pool2d": {
3152 "op": Op.MAX_POOL2D,
3153 "operands": (1, 0),
3154 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003155 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003156 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003157 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003158 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003159 TosaArgGen.agPooling,
3160 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003161 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003162 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003163 "error_if_validators": (
3164 TosaErrorValidator.evKernelSmallerOne,
3165 TosaErrorValidator.evStrideSmallerOne,
3166 TosaErrorValidator.evPadSmallerZero,
3167 TosaErrorValidator.evWrongRank,
3168 TosaErrorValidator.evWrongInputType,
3169 TosaErrorValidator.evWrongOutputType,
3170 TosaErrorValidator.evWrongInputList,
3171 TosaErrorValidator.evWrongOutputList,
3172 TosaErrorValidator.evPadLargerEqualKernel,
3173 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003174 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003175 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003176 "data_gen": {
3177 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3178 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003179 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003180 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003181 "transpose_conv2d_TEMPLATE": {
3182 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003183 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003184 "rank": (4, 4),
3185 "build_fcn": (
3186 build_transpose_conv2d,
3187 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003189 TosaArgGen.agTransposeConv2D,
3190 ),
3191 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003192 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003193 "invalid_test_validators": (
3194 TosaInvalidValidator.ivHeightWidthInvalid,
3195 TosaInvalidValidator.ivNonPositiveOutputShape,
3196 ),
3197 "error_if_validators": (
3198 TosaErrorValidator.evWrongInputType,
3199 TosaErrorValidator.evWrongOutputType,
3200 TosaErrorValidator.evWrongInputList,
3201 TosaErrorValidator.evWrongOutputList,
3202 TosaErrorValidator.evInputZeroPointNotZero,
3203 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003204 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003205 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003206 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003207 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003209 "template": True,
3210 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003211 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003212 "clamp": {
3213 "op": Op.CLAMP,
3214 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 "build_fcn": (
3216 build_clamp,
3217 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003218 TosaTensorValuesGen.tvgLazyGenDefault,
3219 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003221 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003222 "error_if_validators": (
3223 TosaErrorValidator.evMaxSmallerMin,
3224 TosaErrorValidator.evWrongInputType,
3225 TosaErrorValidator.evWrongOutputType,
3226 TosaErrorValidator.evWrongInputList,
3227 TosaErrorValidator.evWrongOutputList,
3228 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003229 "data_gen": {
3230 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3231 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003232 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003233 "sigmoid": {
3234 "op": Op.SIGMOID,
3235 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003236 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003237 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003239 TosaTensorValuesGen.tvgLazyGenDefault,
3240 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003242 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003243 "error_if_validators": (
3244 TosaErrorValidator.evWrongInputType,
3245 TosaErrorValidator.evWrongOutputType,
3246 TosaErrorValidator.evWrongInputList,
3247 TosaErrorValidator.evWrongOutputList,
3248 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003249 "data_gen": {
3250 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3251 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003252 },
3253 "tanh": {
3254 "op": Op.TANH,
3255 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003256 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003257 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003258 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003259 TosaTensorValuesGen.tvgLazyGenDefault,
3260 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003262 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003263 "error_if_validators": (
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003269 "data_gen": {
3270 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3271 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003272 },
Won Jeon78155c62023-06-10 00:20:04 +00003273 "erf": {
3274 "op": Op.ERF,
3275 "operands": (1, 0),
3276 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003277 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003278 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003279 TosaTensorValuesGen.tvgLazyGenDefault,
3280 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003281 ),
3282 "types": TYPE_FP,
3283 "error_if_validators": (
3284 TosaErrorValidator.evWrongInputType,
3285 TosaErrorValidator.evWrongOutputType,
3286 TosaErrorValidator.evWrongInputList,
3287 TosaErrorValidator.evWrongOutputList,
3288 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003289 "data_gen": {
3290 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3291 },
3292 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003294 # Elementwise Binary Operators
3295 "add": {
3296 "op": Op.ADD,
3297 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003298 "build_fcn": (
3299 build_binary_broadcast,
3300 TosaTensorGen.tgBroadcastFuzz,
3301 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003302 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 "error_if_validators": (
3306 TosaErrorValidator.evRankMismatch,
3307 TosaErrorValidator.evWrongInputType,
3308 TosaErrorValidator.evWrongOutputType,
3309 TosaErrorValidator.evWrongInputList,
3310 TosaErrorValidator.evWrongOutputList,
3311 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003312 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003314 "data_gen": {
3315 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3316 },
3317 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 "arithmetic_right_shift": {
3320 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3321 "operands": (2, 0),
3322 "build_fcn": (
3323 build_arithmetic_right_shift,
3324 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003325 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 TosaArgGen.agArithmeticRightShift,
3327 ),
3328 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003329 "error_if_validators": (
3330 TosaErrorValidator.evRankMismatch,
3331 TosaErrorValidator.evWrongInputType,
3332 TosaErrorValidator.evWrongOutputType,
3333 TosaErrorValidator.evWrongInputList,
3334 TosaErrorValidator.evWrongOutputList,
3335 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003336 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003339 "bitwise_and": {
3340 "op": Op.BITWISE_AND,
3341 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_binary_broadcast,
3344 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003345 TosaTensorValuesGen.tvgLazyGenDefault,
3346 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003349 "error_if_validators": (
3350 TosaErrorValidator.evRankMismatch,
3351 TosaErrorValidator.evWrongInputType,
3352 TosaErrorValidator.evWrongOutputType,
3353 TosaErrorValidator.evWrongInputList,
3354 TosaErrorValidator.evWrongOutputList,
3355 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003356 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003357 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003358 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 "bitwise_or": {
3360 "op": Op.BITWISE_OR,
3361 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003362 "build_fcn": (
3363 build_binary_broadcast,
3364 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003365 TosaTensorValuesGen.tvgLazyGenDefault,
3366 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003369 "error_if_validators": (
3370 TosaErrorValidator.evRankMismatch,
3371 TosaErrorValidator.evWrongInputType,
3372 TosaErrorValidator.evWrongOutputType,
3373 TosaErrorValidator.evWrongInputList,
3374 TosaErrorValidator.evWrongOutputList,
3375 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003376 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003377 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003378 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003379 "bitwise_xor": {
3380 "op": Op.BITWISE_XOR,
3381 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003382 "build_fcn": (
3383 build_binary_broadcast,
3384 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003385 TosaTensorValuesGen.tvgLazyGenDefault,
3386 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003389 "error_if_validators": (
3390 TosaErrorValidator.evRankMismatch,
3391 TosaErrorValidator.evWrongInputType,
3392 TosaErrorValidator.evWrongOutputType,
3393 TosaErrorValidator.evWrongInputList,
3394 TosaErrorValidator.evWrongOutputList,
3395 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003396 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003399 "intdiv": {
3400 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003401 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003402 "build_fcn": (
3403 build_binary_broadcast,
3404 TosaTensorGen.tgBroadcastFuzz,
3405 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003406 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003408 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003409 "error_if_validators": (
3410 TosaErrorValidator.evRankMismatch,
3411 TosaErrorValidator.evWrongInputType,
3412 TosaErrorValidator.evWrongOutputType,
3413 TosaErrorValidator.evWrongInputList,
3414 TosaErrorValidator.evWrongOutputList,
3415 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003416 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003417 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "logical_and": {
3420 "op": Op.LOGICAL_AND,
3421 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_binary_broadcast,
3424 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003425 TosaTensorValuesGen.tvgLazyGenDefault,
3426 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evRankMismatch,
3431 TosaErrorValidator.evWrongInputType,
3432 TosaErrorValidator.evWrongOutputType,
3433 TosaErrorValidator.evWrongInputList,
3434 TosaErrorValidator.evWrongOutputList,
3435 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003436 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003439 "logical_left_shift": {
3440 "op": Op.LOGICAL_LEFT_SHIFT,
3441 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 "build_fcn": (
3443 build_binary_broadcast,
3444 TosaTensorGen.tgBroadcastFuzz,
3445 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003446 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 "error_if_validators": (
3450 TosaErrorValidator.evRankMismatch,
3451 TosaErrorValidator.evWrongInputType,
3452 TosaErrorValidator.evWrongOutputType,
3453 TosaErrorValidator.evWrongInputList,
3454 TosaErrorValidator.evWrongOutputList,
3455 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003456 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 "logical_right_shift": {
3460 "op": Op.LOGICAL_RIGHT_SHIFT,
3461 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 "build_fcn": (
3463 build_binary_broadcast,
3464 TosaTensorGen.tgBroadcastFuzz,
3465 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003466 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003468 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 "error_if_validators": (
3470 TosaErrorValidator.evRankMismatch,
3471 TosaErrorValidator.evWrongInputType,
3472 TosaErrorValidator.evWrongOutputType,
3473 TosaErrorValidator.evWrongInputList,
3474 TosaErrorValidator.evWrongOutputList,
3475 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003476 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003477 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "logical_or": {
3480 "op": Op.LOGICAL_OR,
3481 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003482 "build_fcn": (
3483 build_binary_broadcast,
3484 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003485 TosaTensorValuesGen.tvgLazyGenDefault,
3486 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 "error_if_validators": (
3490 TosaErrorValidator.evRankMismatch,
3491 TosaErrorValidator.evWrongInputType,
3492 TosaErrorValidator.evWrongOutputType,
3493 TosaErrorValidator.evWrongInputList,
3494 TosaErrorValidator.evWrongOutputList,
3495 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003496 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003497 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003499 "logical_xor": {
3500 "op": Op.LOGICAL_XOR,
3501 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003502 "build_fcn": (
3503 build_binary_broadcast,
3504 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003505 TosaTensorValuesGen.tvgLazyGenDefault,
3506 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003508 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 "error_if_validators": (
3510 TosaErrorValidator.evRankMismatch,
3511 TosaErrorValidator.evWrongInputType,
3512 TosaErrorValidator.evWrongOutputType,
3513 TosaErrorValidator.evWrongInputList,
3514 TosaErrorValidator.evWrongOutputList,
3515 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003516 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003517 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003519 "maximum": {
3520 "op": Op.MAXIMUM,
3521 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 "build_fcn": (
3523 build_binary_broadcast,
3524 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003525 TosaTensorValuesGen.tvgLazyGenDefault,
3526 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 "error_if_validators": (
3530 TosaErrorValidator.evRankMismatch,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003536 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003537 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003538 "data_gen": {
3539 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3540 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003541 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003542 "minimum": {
3543 "op": Op.MINIMUM,
3544 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 "build_fcn": (
3546 build_binary_broadcast,
3547 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003548 TosaTensorValuesGen.tvgLazyGenDefault,
3549 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003550 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003552 "error_if_validators": (
3553 TosaErrorValidator.evRankMismatch,
3554 TosaErrorValidator.evWrongInputType,
3555 TosaErrorValidator.evWrongOutputType,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003559 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003561 "data_gen": {
3562 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3563 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 "mul": {
3566 "op": Op.MUL,
3567 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 "build_fcn": (
3569 build_mul,
3570 TosaTensorGen.tgBroadcastFuzz,
3571 TosaTensorValuesGen.tvgMul,
3572 TosaArgGen.agMul,
3573 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003575 "error_if_validators": (
3576 TosaErrorValidator.evWrongInputType,
3577 TosaErrorValidator.evWrongOutputType,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 TosaErrorValidator.evRankMismatch,
3581 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003582 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003584 "data_gen": {
3585 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3586 },
3587 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003588 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003589 "pow": {
3590 "op": Op.POW,
3591 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003592 "build_fcn": (
3593 build_binary_broadcast,
3594 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003595 TosaTensorValuesGen.tvgPow,
3596 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003597 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003599 "error_if_validators": (
3600 TosaErrorValidator.evRankMismatch,
3601 TosaErrorValidator.evWrongInputType,
3602 TosaErrorValidator.evWrongOutputType,
3603 TosaErrorValidator.evWrongInputList,
3604 TosaErrorValidator.evWrongOutputList,
3605 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003606 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003607 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003608 "data_gen": {
3609 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003611 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003612 "sub": {
3613 "op": Op.SUB,
3614 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003615 "build_fcn": (
3616 build_binary_broadcast,
3617 TosaTensorGen.tgBroadcastFuzz,
3618 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003619 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003620 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003621 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003622 "error_if_validators": (
3623 TosaErrorValidator.evRankMismatch,
3624 TosaErrorValidator.evWrongInputType,
3625 TosaErrorValidator.evWrongOutputType,
3626 TosaErrorValidator.evWrongInputList,
3627 TosaErrorValidator.evWrongOutputList,
3628 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003629 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003630 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003631 "data_gen": {
3632 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3633 },
3634 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "table": {
3637 "op": Op.TABLE,
3638 # Use the automatic generation functions to create the input array
3639 # but create the table tensor in the build function, as it may be
3640 # a different type from the input
3641 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003642 "build_fcn": (
3643 build_table,
3644 TosaTensorGen.tgBasic,
3645 TosaTensorValuesGen.tvgDefault,
3646 TosaArgGen.agTable,
3647 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003648 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003649 "error_if_validators": (
3650 TosaErrorValidator.evWrongInputType,
3651 TosaErrorValidator.evWrongOutputType,
3652 TosaErrorValidator.evWrongInputList,
3653 TosaErrorValidator.evWrongOutputList,
3654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 # Elementwise Unary operators
3657 "abs": {
3658 "op": Op.ABS,
3659 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 "build_fcn": (
3661 build_unary,
3662 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003663 TosaTensorValuesGen.tvgLazyGenDefault,
3664 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003665 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 "error_if_validators": (
3668 TosaErrorValidator.evWrongInputType,
3669 TosaErrorValidator.evWrongOutputType,
3670 TosaErrorValidator.evWrongInputList,
3671 TosaErrorValidator.evWrongOutputList,
3672 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003673 "data_gen": {
3674 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3675 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003677 "bitwise_not": {
3678 "op": Op.BITWISE_NOT,
3679 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 "build_fcn": (
3681 build_unary,
3682 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003683 TosaTensorValuesGen.tvgLazyGenDefault,
3684 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 "error_if_validators": (
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003694 "ceil": {
3695 "op": Op.CEIL,
3696 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 "build_fcn": (
3698 build_unary,
3699 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003700 TosaTensorValuesGen.tvgLazyGenDefault,
3701 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003704 "error_if_validators": (
3705 TosaErrorValidator.evWrongInputType,
3706 TosaErrorValidator.evWrongOutputType,
3707 TosaErrorValidator.evWrongInputList,
3708 TosaErrorValidator.evWrongOutputList,
3709 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003710 "data_gen": {
3711 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3712 },
3713 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003714 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003715 "clz": {
3716 "op": Op.CLZ,
3717 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003718 "build_fcn": (
3719 build_unary,
3720 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003721 TosaTensorValuesGen.tvgLazyGenDefault,
3722 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003723 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003725 "error_if_validators": (
3726 TosaErrorValidator.evWrongInputType,
3727 TosaErrorValidator.evWrongOutputType,
3728 TosaErrorValidator.evWrongInputList,
3729 TosaErrorValidator.evWrongOutputList,
3730 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003732 "exp": {
3733 "op": Op.EXP,
3734 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_unary,
3737 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003738 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003739 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
3747 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003748 "data_gen": {
3749 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3750 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 "floor": {
3753 "op": Op.FLOOR,
3754 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003755 "build_fcn": (
3756 build_unary,
3757 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003758 TosaTensorValuesGen.tvgLazyGenDefault,
3759 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003760 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 "error_if_validators": (
3763 TosaErrorValidator.evWrongInputType,
3764 TosaErrorValidator.evWrongOutputType,
3765 TosaErrorValidator.evWrongInputList,
3766 TosaErrorValidator.evWrongOutputList,
3767 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003768 "data_gen": {
3769 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3770 },
3771 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "log": {
3774 "op": Op.LOG,
3775 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003776 "build_fcn": (
3777 build_unary,
3778 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003779 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003780 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003781 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003783 "error_if_validators": (
3784 TosaErrorValidator.evWrongInputType,
3785 TosaErrorValidator.evWrongOutputType,
3786 TosaErrorValidator.evWrongInputList,
3787 TosaErrorValidator.evWrongOutputList,
3788 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003789 "data_gen": {
3790 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3791 },
3792 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003794 "logical_not": {
3795 "op": Op.LOGICAL_NOT,
3796 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 "build_fcn": (
3798 build_unary,
3799 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003800 TosaTensorValuesGen.tvgLazyGenDefault,
3801 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 "error_if_validators": (
3805 TosaErrorValidator.evWrongInputType,
3806 TosaErrorValidator.evWrongOutputType,
3807 TosaErrorValidator.evWrongInputList,
3808 TosaErrorValidator.evWrongOutputList,
3809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003810 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003811 "negate": {
3812 "op": Op.NEGATE,
3813 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 "build_fcn": (
3815 build_unary,
3816 TosaTensorGen.tgBasic,
3817 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003818 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 "qgen": TosaQuantGen.qgUnary,
3821 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003822 "error_if_validators": (
3823 TosaErrorValidator.evInputZeroPointNotZero,
3824 TosaErrorValidator.evOutputZeroPointNotZero,
3825 TosaErrorValidator.evWrongInputType,
3826 TosaErrorValidator.evWrongOutputType,
3827 TosaErrorValidator.evWrongInputList,
3828 TosaErrorValidator.evWrongOutputList,
3829 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003830 "data_gen": {
3831 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3832 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003834 "reciprocal": {
3835 "op": Op.RECIPROCAL,
3836 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 "build_fcn": (
3838 build_unary,
3839 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003840 TosaTensorValuesGen.tvgLazyGenDefault,
3841 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003844 "error_if_validators": (
3845 TosaErrorValidator.evWrongInputType,
3846 TosaErrorValidator.evWrongOutputType,
3847 TosaErrorValidator.evWrongInputList,
3848 TosaErrorValidator.evWrongOutputList,
3849 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003850 "data_gen": {
3851 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3852 },
3853 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 "rsqrt": {
3856 "op": Op.RSQRT,
3857 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003858 "build_fcn": (
3859 build_unary,
3860 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003861 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003862 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003863 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003865 "error_if_validators": (
3866 TosaErrorValidator.evWrongInputType,
3867 TosaErrorValidator.evWrongOutputType,
3868 TosaErrorValidator.evWrongInputList,
3869 TosaErrorValidator.evWrongOutputList,
3870 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003871 "data_gen": {
3872 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3873 },
3874 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003875 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 # Elementwise Ternary operators
3877 "select": {
3878 "op": Op.SELECT,
3879 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 "build_fcn": (
3881 build_select,
3882 TosaTensorGen.tgBroadcastFuzz,
3883 TosaTensorValuesGen.tvgSelect,
3884 None,
3885 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003886 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003887 "error_if_validators": (
3888 TosaErrorValidator.evRankMismatch,
3889 TosaErrorValidator.evWrongInputType,
3890 TosaErrorValidator.evWrongOutputType,
3891 TosaErrorValidator.evWrongInputList,
3892 TosaErrorValidator.evWrongOutputList,
3893 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003894 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003896 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003897 # Comparison operators
3898 "equal": {
3899 "op": Op.EQUAL,
3900 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003901 "build_fcn": (
3902 build_comparison,
3903 TosaTensorGen.tgBroadcastFuzz,
3904 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003905 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003908 "error_if_validators": (
3909 TosaErrorValidator.evRankMismatch,
3910 TosaErrorValidator.evWrongInputType,
3911 TosaErrorValidator.evWrongOutputType,
3912 TosaErrorValidator.evWrongInputList,
3913 TosaErrorValidator.evWrongOutputList,
3914 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003915 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003917 "data_gen": {
3918 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3919 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "greater_equal": {
3922 "op": Op.GREATER_EQUAL,
3923 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_comparison,
3926 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
3928 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evRankMismatch,
3933 TosaErrorValidator.evWrongInputType,
3934 TosaErrorValidator.evWrongOutputType,
3935 TosaErrorValidator.evWrongInputList,
3936 TosaErrorValidator.evWrongOutputList,
3937 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003938 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003940 "data_gen": {
3941 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3942 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 "greater": {
3945 "op": Op.GREATER,
3946 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003947 "build_fcn": (
3948 build_comparison,
3949 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003950 TosaTensorValuesGen.tvgLazyGenDefault,
3951 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003954 "error_if_validators": (
3955 TosaErrorValidator.evRankMismatch,
3956 TosaErrorValidator.evWrongInputType,
3957 TosaErrorValidator.evWrongOutputType,
3958 TosaErrorValidator.evWrongInputList,
3959 TosaErrorValidator.evWrongOutputList,
3960 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003961 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003963 "data_gen": {
3964 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 # Reduction operators
3968 "reduce_all": {
3969 "op": Op.REDUCE_ALL,
3970 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003971 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 "build_fcn": (
3973 build_reduce,
3974 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003975 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003976 TosaArgGen.agAxis,
3977 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003978 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003979 "error_if_validators": (
3980 TosaErrorValidator.evAxisLargerRank,
3981 TosaErrorValidator.evAxisSmallerZero,
3982 TosaErrorValidator.evShapeOfAxisNotOne,
3983 TosaErrorValidator.evWrongInputType,
3984 TosaErrorValidator.evWrongOutputType,
3985 TosaErrorValidator.evWrongRank,
3986 TosaErrorValidator.evWrongInputList,
3987 TosaErrorValidator.evWrongOutputList,
3988 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 "reduce_any": {
3991 "op": Op.REDUCE_ANY,
3992 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003993 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_reduce,
3996 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003997 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003998 TosaArgGen.agAxis,
3999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evAxisLargerRank,
4003 TosaErrorValidator.evAxisSmallerZero,
4004 TosaErrorValidator.evShapeOfAxisNotOne,
4005 TosaErrorValidator.evWrongInputType,
4006 TosaErrorValidator.evWrongOutputType,
4007 TosaErrorValidator.evWrongRank,
4008 TosaErrorValidator.evWrongInputList,
4009 TosaErrorValidator.evWrongOutputList,
4010 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004012 "reduce_max": {
4013 "op": Op.REDUCE_MAX,
4014 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004015 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 "build_fcn": (
4017 build_reduce,
4018 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004019 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004020 TosaArgGen.agAxis,
4021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004023 "error_if_validators": (
4024 TosaErrorValidator.evAxisLargerRank,
4025 TosaErrorValidator.evAxisSmallerZero,
4026 TosaErrorValidator.evShapeOfAxisNotOne,
4027 TosaErrorValidator.evWrongInputType,
4028 TosaErrorValidator.evWrongOutputType,
4029 TosaErrorValidator.evWrongRank,
4030 TosaErrorValidator.evWrongInputList,
4031 TosaErrorValidator.evWrongOutputList,
4032 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004033 "data_gen": {
4034 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4035 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004036 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004038 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004039 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004040 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004041 "build_fcn": (
4042 build_reduce,
4043 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004044 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004045 TosaArgGen.agAxis,
4046 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004047 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004048 "error_if_validators": (
4049 TosaErrorValidator.evAxisLargerRank,
4050 TosaErrorValidator.evAxisSmallerZero,
4051 TosaErrorValidator.evShapeOfAxisNotOne,
4052 TosaErrorValidator.evWrongInputType,
4053 TosaErrorValidator.evWrongOutputType,
4054 TosaErrorValidator.evWrongRank,
4055 TosaErrorValidator.evWrongInputList,
4056 TosaErrorValidator.evWrongOutputList,
4057 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004058 "data_gen": {
4059 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004062 "reduce_product": {
4063 "op": Op.REDUCE_PRODUCT,
4064 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004065 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004066 "build_fcn": (
4067 build_reduce,
4068 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004069 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004070 TosaArgGen.agAxis,
4071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004073 "error_if_validators": (
4074 TosaErrorValidator.evAxisLargerRank,
4075 TosaErrorValidator.evAxisSmallerZero,
4076 TosaErrorValidator.evShapeOfAxisNotOne,
4077 TosaErrorValidator.evWrongInputType,
4078 TosaErrorValidator.evWrongOutputType,
4079 TosaErrorValidator.evWrongRank,
4080 TosaErrorValidator.evWrongInputList,
4081 TosaErrorValidator.evWrongOutputList,
4082 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004084 "reduce_sum": {
4085 "op": Op.REDUCE_SUM,
4086 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004087 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004088 "build_fcn": (
4089 build_reduce,
4090 TosaTensorGen.tgBasic,
4091 TosaTensorValuesGen.tvgReduceSum,
4092 TosaArgGen.agAxis,
4093 ),
James Ward24dbc422022-10-19 12:20:31 +01004094 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004095 "error_if_validators": (
4096 TosaErrorValidator.evAxisLargerRank,
4097 TosaErrorValidator.evAxisSmallerZero,
4098 TosaErrorValidator.evShapeOfAxisNotOne,
4099 TosaErrorValidator.evWrongInputType,
4100 TosaErrorValidator.evWrongOutputType,
4101 TosaErrorValidator.evWrongRank,
4102 TosaErrorValidator.evWrongInputList,
4103 TosaErrorValidator.evWrongOutputList,
4104 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004105 "data_gen": {
4106 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004109 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004110 "concat": {
4111 "op": Op.CONCAT,
4112 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004113 "build_fcn": (
4114 build_concat,
4115 TosaTensorGen.tgConcat,
4116 TosaTensorValuesGen.tvgConcat,
4117 TosaArgGen.agAxis,
4118 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004119 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004120 "error_if_validators": (
4121 TosaErrorValidator.evAxisLargerRank,
4122 TosaErrorValidator.evAxisSmallerZero,
4123 TosaErrorValidator.evConcatInputRankMismatch,
4124 TosaErrorValidator.evConcatShapeSumMismatch,
4125 TosaErrorValidator.evConcatInputDimMismatch,
4126 TosaErrorValidator.evWrongInputType,
4127 TosaErrorValidator.evWrongOutputType,
4128 TosaErrorValidator.evWrongOutputList,
4129 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004130 },
4131 "pad": {
4132 "op": Op.PAD,
4133 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004134 "build_fcn": (
4135 build_pad,
4136 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004137 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004138 TosaArgGen.agPad,
4139 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004140 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004141 "error_if_validators": (
4142 TosaErrorValidator.evWrongInputType,
4143 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004144 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004145 TosaErrorValidator.evWrongOutputType,
4146 TosaErrorValidator.evWrongInputList,
4147 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004148 TosaErrorValidator.evRankMismatch,
4149 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004150 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004151 "data_gen": {
4152 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4153 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004154 },
Won Jeona21b2e82023-08-10 10:33:01 +00004155 "dim": {
4156 "op": Op.DIM,
4157 "operands": (1, 0),
4158 "build_fcn": (
4159 build_dim,
4160 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004161 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004162 TosaArgGen.agAxis,
4163 ),
4164 "types": TYPE_FIB,
4165 "error_if_validators": (
4166 TosaErrorValidator.evAxisLargerRank,
4167 TosaErrorValidator.evAxisSmallerZero,
4168 TosaErrorValidator.evWrongInputType,
4169 TosaErrorValidator.evWrongInputList,
4170 TosaErrorValidator.evWrongOutputList,
4171 TosaErrorValidator.evWrongRank,
4172 ),
4173 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004174 "reshape": {
4175 "op": Op.RESHAPE,
4176 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004177 "build_fcn": (
4178 build_reshape,
4179 TosaTensorGen.tgBasic,
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004180 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 TosaArgGen.agReshape,
4182 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004183 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004184 "error_if_validators": (
4185 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4186 TosaErrorValidator.evWrongInputType,
4187 TosaErrorValidator.evWrongOutputType,
4188 TosaErrorValidator.evWrongInputList,
4189 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004190 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4191 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004192 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004193 "data_gen": {
4194 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4195 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004196 },
4197 "reverse": {
4198 "op": Op.REVERSE,
4199 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004200 "build_fcn": (
4201 build_reverse,
4202 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004203 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004204 TosaArgGen.agAxis,
4205 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004206 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004207 "error_if_validators": (
4208 TosaErrorValidator.evAxisSmallerZero,
4209 TosaErrorValidator.evAxisLargerRank,
4210 TosaErrorValidator.evWrongInputType,
4211 TosaErrorValidator.evWrongOutputType,
4212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
4214 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004215 },
4216 "slice": {
4217 "op": Op.SLICE,
4218 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004219 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 "build_fcn": (
4221 build_slice,
4222 TosaTensorGen.tgBasic,
4223 TosaTensorValuesGen.tvgDefault,
4224 TosaArgGen.agSlice,
4225 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004226 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004227 "error_if_validators": (
4228 TosaErrorValidator.evStartSmallerZero,
4229 TosaErrorValidator.evSizeSmallerEqualZero,
4230 TosaErrorValidator.evStartSizeOutsideBounds,
4231 TosaErrorValidator.evSizeOutputShapeMismatch,
4232 TosaErrorValidator.evInputSizeStartLengthMismatch,
4233 TosaErrorValidator.evWrongRank,
4234 TosaErrorValidator.evWrongInputType,
4235 TosaErrorValidator.evWrongOutputType,
4236 TosaErrorValidator.evWrongInputList,
4237 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004238 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004239 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004240 },
4241 "tile": {
4242 "op": Op.TILE,
4243 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004244 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004245 "build_fcn": (
4246 build_tile,
4247 TosaTensorGen.tgBasic,
4248 TosaTensorValuesGen.tvgDefault,
4249 TosaArgGen.agTile,
4250 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004252 "error_if_validators": (
4253 TosaErrorValidator.evWrongInputType,
4254 TosaErrorValidator.evWrongOutputType,
4255 TosaErrorValidator.evWrongInputList,
4256 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004257 TosaErrorValidator.evRankMismatch,
4258 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004259 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004260 },
4261 "transpose": {
4262 "op": Op.TRANSPOSE,
4263 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004264 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004265 "build_fcn": (
4266 build_transpose,
4267 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004268 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004269 TosaArgGen.agTranspose,
4270 ),
4271 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004272 "error_if_validators": (
4273 TosaErrorValidator.evIndexOutsideBounds,
4274 TosaErrorValidator.evIndexUsedTwice,
4275 TosaErrorValidator.evWrongInputType,
4276 TosaErrorValidator.evWrongOutputType,
4277 TosaErrorValidator.evWrongInputList,
4278 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004279 TosaErrorValidator.evWrongRank,
4280 TosaErrorValidator.evRankMismatch,
4281 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004282 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004284 # Data nodes
4285 "const": {
4286 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004287 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004288 "build_fcn": (
4289 build_const,
4290 TosaTensorGen.tgBasic,
4291 TosaTensorValuesGen.tvgDefault,
4292 None,
4293 ),
Luke Hutton65872422023-02-20 10:33:04 +00004294 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004295 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004296 "identity": {
4297 "op": Op.IDENTITY,
4298 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 "build_fcn": (
4300 build_unary,
4301 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004302 TosaTensorValuesGen.tvgLazyGenDefault,
4303 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004304 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004305 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004306 "data_gen": {
4307 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004309 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004310 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004311 "gather": {
4312 "op": Op.GATHER,
4313 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4314 "operands": (1, 0),
4315 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 "build_fcn": (
4317 build_gather,
4318 TosaTensorGen.tgBasic,
4319 TosaTensorValuesGen.tvgDefault,
4320 None,
4321 ),
James Ward24dbc422022-10-19 12:20:31 +01004322 "types": (
4323 DType.INT8,
4324 DType.INT16,
4325 DType.INT32,
4326 DType.FP16,
4327 DType.BF16,
4328 DType.FP32,
4329 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004330 "error_if_validators": (
4331 TosaErrorValidator.evWrongInputType,
4332 TosaErrorValidator.evWrongOutputType,
4333 TosaErrorValidator.evWrongInputList,
4334 TosaErrorValidator.evWrongOutputList,
4335 TosaErrorValidator.evWrongRank,
4336 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004337 },
4338 "scatter": {
4339 "op": Op.SCATTER,
4340 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004341 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 "operands": (2, 0),
4343 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004344 "build_fcn": (
4345 build_scatter,
4346 TosaTensorGen.tgScatter,
4347 TosaTensorValuesGen.tvgDefault,
4348 None,
4349 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004350 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004351 "error_if_validators": (
4352 TosaErrorValidator.evWrongInputType,
4353 TosaErrorValidator.evWrongOutputType,
4354 TosaErrorValidator.evWrongInputList,
4355 TosaErrorValidator.evWrongOutputList,
4356 TosaErrorValidator.evWrongRank,
4357 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004358 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004359 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004360 "resize": {
4361 "op": Op.RESIZE,
4362 "operands": (1, 0),
4363 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004364 "build_fcn": (
4365 build_resize,
4366 TosaTensorGen.tgNHWC,
4367 TosaTensorValuesGen.tvgDefault,
4368 TosaArgGen.agResize,
4369 ),
James Ward24dbc422022-10-19 12:20:31 +01004370 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004371 "invalid_test_validators": (
4372 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004373 ),
4374 "error_if_validators": (
4375 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004376 TosaErrorValidator.evScaleSmallerEqualZero,
4377 TosaErrorValidator.evScaleNLargerMax,
4378 TosaErrorValidator.evScaleDLargerMax,
4379 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004380 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004381 TosaErrorValidator.evBorderSmallerMin,
4382 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004383 TosaErrorValidator.evWrongInputType,
4384 TosaErrorValidator.evWrongOutputType,
4385 TosaErrorValidator.evWrongRank,
4386 TosaErrorValidator.evWrongInputList,
4387 TosaErrorValidator.evWrongOutputList,
4388 TosaErrorValidator.evBatchMismatch,
4389 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004390 TosaErrorValidator.evResizeOutputShapeMismatch,
4391 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004392 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004393 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004394 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004395 "cast": {
4396 "op": Op.CAST,
4397 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004398 "build_fcn": (
4399 build_cast,
4400 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004401 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004402 TosaArgGen.agCast,
4403 ),
James Ward8b390432022-08-12 20:48:56 +01004404 "types": (
4405 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004406 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004407 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004408 DType.INT8,
4409 DType.INT16,
4410 DType.INT32,
4411 DType.BOOL,
4412 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004413 "error_if_validators": (
4414 TosaErrorValidator.evWrongInputType,
4415 TosaErrorValidator.evWrongOutputType,
4416 TosaErrorValidator.evWrongInputList,
4417 TosaErrorValidator.evWrongOutputList,
4418 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004419 "data_gen": {
4420 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4421 },
4422 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004423 },
4424 "rescale": {
4425 "op": Op.RESCALE,
4426 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004427 "build_fcn": (
4428 build_rescale,
4429 TosaTensorGen.tgBasic,
4430 TosaTensorValuesGen.tvgDefault,
4431 TosaArgGen.agRescale,
4432 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004433 "types": [
4434 DType.UINT8,
4435 DType.INT8,
4436 DType.INT16,
4437 DType.INT32,
4438 DType.INT48,
4439 DType.UINT16,
4440 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004441 "error_if_validators": (
4442 TosaErrorValidator.evInputZeroPointNotZero,
4443 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004444 TosaErrorValidator.evU16InputZeroPointNotValid,
4445 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004446 TosaErrorValidator.evScaleTrue,
4447 TosaErrorValidator.evScaleNotTrue,
4448 TosaErrorValidator.evWrongInputType,
4449 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004450 TosaErrorValidator.evWrongInputList,
4451 TosaErrorValidator.evWrongOutputList,
4452 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004453 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004454 # Custom
4455 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004456 # Control flow operators
# Two variants of cond_if: one that generates one of two constant tensors
# (no inputs to the basic blocks, one output), and another that either adds
# or subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004460 "cond_if_const": {
4461 "op": Op.COND_IF,
4462 "operands": (0, 2),
4463 "build_fcn": (
4464 build_cond_if_const,
4465 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004466 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467 TosaArgGen.agCondIf,
4468 ),
4469 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 "error_if_validators": (
4471 TosaErrorValidator.evOutputListThenGraphMismatch,
4472 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004473 TosaErrorValidator.evCondIfCondNotMatchingBool,
4474 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004475 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004476 },
4477 "cond_if_binary": {
4478 "op": Op.COND_IF,
4479 "operands": (2, 0),
4480 "build_fcn": (
4481 build_cond_if_binary,
4482 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004483 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004484 TosaArgGen.agCondIf,
4485 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004486 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004487 "error_if_validators": (
4488 TosaErrorValidator.evInputListThenGraphMismatch,
4489 TosaErrorValidator.evInputListElseGraphMismatch,
4490 TosaErrorValidator.evOutputListThenGraphMismatch,
4491 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004492 TosaErrorValidator.evCondIfCondNotMatchingBool,
4493 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004494 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004495 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004496 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004497 "while_loop": {
4498 "op": Op.WHILE_LOOP,
4499 "operands": (0, 1),
4500 "build_fcn": (
4501 build_while_loop,
4502 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004503 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004504 TosaArgGen.agWhileLoop,
4505 ),
4506 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004507 "error_if_validators": (
4508 TosaErrorValidator.evInputListOutputListMismatch,
4509 TosaErrorValidator.evInputListCondGraphMismatch,
4510 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4511 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4512 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004513 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004514 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004515 },
Luke Hutton57287132023-02-06 14:54:18 +00004516 "fft2d": {
4517 "op": Op.FFT2D,
4518 "operands": (2, 0),
4519 "rank": (3, 3),
4520 "build_fcn": (
4521 build_fft2d,
4522 TosaTensorGen.tgFFT2d,
4523 TosaTensorValuesGen.tvgDefault,
4524 TosaArgGen.agFFT2d,
4525 ),
4526 "types": [DType.FP32],
4527 "error_if_validators": (
4528 TosaErrorValidator.evWrongInputType,
4529 TosaErrorValidator.evWrongOutputType,
4530 TosaErrorValidator.evWrongInputList,
4531 TosaErrorValidator.evWrongOutputList,
4532 TosaErrorValidator.evWrongRank,
4533 TosaErrorValidator.evBatchMismatch,
4534 TosaErrorValidator.evKernelNotPowerOfTwo,
4535 TosaErrorValidator.evFFTInputShapeMismatch,
4536 TosaErrorValidator.evFFTOutputShapeMismatch,
4537 ),
4538 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004539 "rfft2d": {
4540 "op": Op.RFFT2D,
4541 "operands": (1, 0),
4542 "rank": (3, 3),
4543 "build_fcn": (
4544 build_rfft2d,
4545 TosaTensorGen.tgRFFT2d,
4546 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004547 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004548 ),
4549 "types": [DType.FP32],
4550 "error_if_validators": (
4551 TosaErrorValidator.evWrongInputType,
4552 TosaErrorValidator.evWrongOutputType,
4553 TosaErrorValidator.evWrongInputList,
4554 TosaErrorValidator.evWrongOutputList,
4555 TosaErrorValidator.evWrongRank,
4556 TosaErrorValidator.evBatchMismatch,
4557 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004558 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004559 ),
4560 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004561 }
4562
Kevin Cheng550ccc52021-03-03 11:21:43 -08004563
Eric Kunzee5e26762020-10-13 16:11:07 -07004564class OutputShaper:
4565 # Methods in this class compute the expected output shape and datatype
4566 # for common classes of operations
4567 def __init__(self):
4568 pass
4569
# Each of these methods works out the expected output shape and dtype,
# passes them to ser.addOutput(), and returns the result of that call
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output of a broadcasting elementwise binary operator.

    With no ERROR_IF case requested, every dimension of ``a`` equal to 1 is
    replaced by the matching dimension of ``b``. When an error case is
    requested, the dimensions of ``a`` are kept as they are.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` is called once.
        rng: random source providing ``integers`` and ``choice``
            (presumably a numpy Generator - confirm with callers).
        a: first input; only ``shape`` and ``dtype`` are read.
        b: second input; only ``shape`` and ``dtype`` are read.
        error_name: ``None`` or an ``ErrorIf`` value. RankMismatch skips the
            rank assertion, DimensionMismatch adds 1 to one output dimension,
            and WrongOutputType picks a dtype different from ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        b.shape[pos] if (dim == 1 and broadcast) else dim
        for pos, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: moving this into the branch below would shift
    # every later value taken from rng.
    bump_pos = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_pos] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        }
        out_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004605
4606 @staticmethod
4607 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004608 assert len(a.shape) == len(b.shape)
4609 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004610
4611 shape = []
4612 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004613 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004614 shape.append(a.shape[i])
4615
Kevin Cheng550ccc52021-03-03 11:21:43 -08004616 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004617
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output of an elementwise unary operator.

    The output always has ``a.shape``. Its dtype is ``a.dtype``, unless
    ``error_name`` is ``ErrorIf.WrongOutputType``; in that case a different
    dtype is chosen with ``rng.choice``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004636
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output of a select-style operator over ``cond``, ``a``, ``b``.

    Output dimensions follow ``cond``. With no error case requested, a
    dimension of ``cond`` equal to 1 becomes the largest of the matching
    ``cond``/``a``/``b`` dimensions. DimensionMismatch adds 1 to one output
    dimension, and WrongOutputType picks a dtype other than ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        max(dim, a.shape[pos], b.shape[pos]) if (dim == 1 and broadcast) else dim
        for pos, dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally (sized by ``a``): moving it would shift every
    # later value taken from rng.
    bump_pos = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_pos] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004670
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output of an elementwise comparison operator.

    A dimension of ``a`` equal to 1 takes the matching dimension of ``b``
    whenever ``b`` has that axis, whether or not an error case is requested.
    NOTE(review): ``binaryBroadcastOp`` only broadcasts when ``error_name``
    is None; confirm this difference is intentional.

    The dtype is ``DType.BOOL``, unless ``error_name`` is
    ``ErrorIf.WrongOutputType``; in that case a non-BOOL dtype is chosen
    with ``rng.choice``. DimensionMismatch adds 1 to one output dimension.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    b_rank = len(b.shape)
    out_shape = [
        b.shape[pos] if (dim == 1 and pos < b_rank) else dim
        for pos, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: moving it would shift every later rng value.
    bump_pos = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_pos] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004704
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output of a reduction along ``axis``.

    Normally the output is ``a.shape`` with ``axis`` set to 1. For the
    AxisSmallerZero and AxisLargerRank error cases the shape is left
    unchanged. ShapeOfAxisNotOne also leaves it unchanged, except that an
    axis which is already 1 is replaced with ``rng.integers(2, 10)``.
    WrongOutputType picks a dtype other than ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # Ensure the reduced axis is not 1, so the error condition holds.
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004733
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the index output of an argmax over ``axis``.

    Normally the output is ``a.shape`` with ``axis`` removed; the axis is
    kept for the AxisSmallerZero and AxisLargerRank error cases. Then:

    * ArgmaxOutputRankMismatch drops the leading dimension (only when more
      than one remains, per a coin flip), otherwise appends a 1.
    * ArgmaxOutputShapeMismatch adds ``rng.integers(1, 10)`` to every
      dimension.
    * WrongOutputType picks a dtype other than ``DType.INT32``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004767
4768 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004769 def conv2dOp(
4770 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4771 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004772
4773 # IFM: NHWC
4774 # Filter: OHWI
4775 # OFM: NHWC
4776
Kevin Cheng550ccc52021-03-03 11:21:43 -08004777 h = (
4778 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004779 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004780 + padding[0]
4781 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004782 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004783 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004784
Kevin Cheng550ccc52021-03-03 11:21:43 -08004785 w = (
4786 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004787 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004788 + padding[2]
4789 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004790 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004791 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004792
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004793 if error_name == ErrorIf.ConvOutputShapeMismatch:
4794 choices = [1, 2, 3]
4795 change = rng.choice(choices)
4796 # increment in multiples of stride to not hit non-integer error case
4797 if change in [1, 3]:
4798 h = h + (rng.choice(choices) * strides[0])
4799 if change in [2, 3]:
4800 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004801
Eric Kunzee5e26762020-10-13 16:11:07 -07004802 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4803
James Ward8b390432022-08-12 20:48:56 +01004804 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004805 # Pick some potentially correct output dtype if input type is incorrect
4806 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004807 else:
James Ward8b390432022-08-12 20:48:56 +01004808 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004809
4810 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004811 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004812 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004813 else:
4814 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01004815 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004816 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004817
Kevin Cheng550ccc52021-03-03 11:21:43 -08004818 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004819
4820 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004821 def conv3dOp(
4822 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4823 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004824
4825 # IFM: NDHWC
4826 # Filter: ODHWI
4827 # OFM: NDHWC
4828
4829 d = (
4830 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004831 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004832 + padding[0]
4833 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004834 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004835 ) // strides[0] + 1
4836
4837 h = (
4838 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004839 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004840 + padding[2]
4841 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004842 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004843 ) // strides[1] + 1
4844
4845 w = (
4846 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004847 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004848 + padding[4]
4849 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004850 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004851 ) // strides[2] + 1
4852
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004853 if error_name == ErrorIf.ConvOutputShapeMismatch:
4854 choices = [1, 2, 3, 4]
4855 change = rng.choice(choices)
4856 # increment in multiples of stride to not hit non-integer error case
4857 if change in [1, 4]:
4858 d = d + (rng.choice(choices) * strides[0])
4859 if change in [2, 4]:
4860 h = h + (rng.choice(choices) * strides[1])
4861 if change in [3, 4]:
4862 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004863
Kevin Cheng1533b852021-09-01 12:51:58 -07004864 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4865
James Ward8b390432022-08-12 20:48:56 +01004866 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004867 # Pick some potentially correct output dtype if input type is incorrect
4868 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004869 else:
James Ward8b390432022-08-12 20:48:56 +01004870 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004871
4872 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004873 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004874 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004875 else:
4876 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01004877 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004878 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004879
4880 return ser.addOutput(ofm_shape, out_dtype)
4881
4882 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004883 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004884 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004885 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004886 # IFM: NHWC
4887 # Filter: HWCM
4888 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004889
Kevin Cheng550ccc52021-03-03 11:21:43 -08004890 h = (
4891 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004892 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004893 + padding[0]
4894 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004895 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004896 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004897
Kevin Cheng550ccc52021-03-03 11:21:43 -08004898 w = (
4899 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004900 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004901 + padding[2]
4902 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004903 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004904 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004905
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004906 if error_name == ErrorIf.ConvOutputShapeMismatch:
4907 choices = [1, 2, 3]
4908 change = rng.choice(choices)
4909 # increment in multiples of stride to not hit non-integer error case
4910 if change in [1, 3]:
4911 h = h + (rng.choice(choices) * strides[0])
4912 if change in [2, 3]:
4913 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004914
Eric Kunzee5e26762020-10-13 16:11:07 -07004915 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4916
James Ward8b390432022-08-12 20:48:56 +01004917 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004918 # Pick some potentially correct output dtype if input type is incorrect
4919 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004920 else:
James Ward8b390432022-08-12 20:48:56 +01004921 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004922
4923 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004924 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004925 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004926 else:
4927 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01004928 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004929 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004930
Kevin Cheng550ccc52021-03-03 11:21:43 -08004931 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004932
4933 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004934 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004935 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004936 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004937 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004938 h = 1
4939 w = 1
4940 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004941 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
4942 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004943
4944 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004945 choices = [1, 2, 3]
4946 change = rng.choice(choices)
4947 # increment in multiples of stride to not hit non-integer error case
4948 if change in [1, 3]:
4949 h = h + (rng.choice(choices) * stride[0])
4950 if change in [2, 3]:
4951 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07004952 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004953
4954 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004955 all_dtypes = [
4956 DType.INT8,
4957 DType.INT16,
4958 DType.INT32,
4959 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004960 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004961 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004962 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004963 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004964 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4965 outputDType = rng.choice(wrong_dtypes)
4966 else:
4967 outputDType = ifm.dtype
4968
4969 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004970
4971 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004972 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004973 # input: N, IC
4974 # filter: OC, IC
4975 # output: N, OC
4976
4977 output_shape = [input.shape[0], filter.shape[0]]
4978
James Ward8b390432022-08-12 20:48:56 +01004979 # Validated in arg_gen (also invalidated for ErrorIf)
4980 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004981
Kevin Cheng550ccc52021-03-03 11:21:43 -08004982 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004983
4984 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004985 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004986 # a: N, H, C
4987 # b: N, C, W
4988 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004989
Kevin Cheng2d60f002021-06-09 14:18:32 -07004990 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004991
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004992 if error_name == ErrorIf.WrongOutputType:
4993 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004994 incorrect_types = (
4995 DType.INT4,
4996 DType.INT8,
4997 DType.INT16,
4998 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004999 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005000 DType.FP16,
5001 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005002 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005003 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005004 incorrect_types = (
5005 DType.INT4,
5006 DType.INT8,
5007 DType.INT16,
5008 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005009 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005010 DType.FP16,
5011 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005012 )
James Ward24dbc422022-10-19 12:20:31 +01005013 elif (
5014 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
5015 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005016 incorrect_types = (
5017 DType.INT4,
5018 DType.INT8,
5019 DType.INT16,
5020 DType.INT32,
5021 DType.INT48,
5022 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005023 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005024 elif error_name == ErrorIf.WrongInputType:
5025 # Pick some potentially correct output dtype if input type is incorrect
5026 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005027 else:
James Ward8b390432022-08-12 20:48:56 +01005028 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07005029
Kevin Cheng550ccc52021-03-03 11:21:43 -08005030 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005031
5032 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00005033 def concatOp(ser, rng, axis, inputs, error_name=None):
5034 input1 = inputs[0]
5035 remaining_inputs = inputs[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005036
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005037 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01005038 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005039 if not (
5040 # unable to concat tensors of different ranks
5041 error_name == ErrorIf.ConcatInputRankMismatch
5042 # unable to concat tensors along an invalid axis
5043 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005044 ):
5045 for tensor in remaining_inputs:
5046 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07005047
Matthew Haddon01c359d2021-10-15 16:30:48 +01005048 if error_name == ErrorIf.ConcatShapeSumMismatch:
5049 output_shape[axis] += rng.integers(5, 10)
5050
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005051 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005052 all_dtypes = {
5053 DType.INT8,
5054 DType.INT16,
5055 DType.INT32,
5056 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005057 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005058 DType.FP16,
5059 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005060 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005061 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
5062 outputDType = rng.choice(wrong_dtypes)
5063 else:
5064 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01005065
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005066 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005067
5068 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005069 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005070
5071 output_shape = a.shape.copy()
5072
5073 for i in range(len(output_shape)):
5074 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
5075
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01005076 if error_name == ErrorIf.PadOutputShapeMismatch:
5077 bad_dim = rng.choice(range(len(output_shape)))
5078 output_shape[bad_dim] -= rng.choice([1, 2])
Luke Huttona4e48ca2023-02-22 11:53:48 +00005079 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005080 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01005081
Matthew Haddone807aae2021-10-11 18:12:58 +01005082 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005083 all_dtypes = [
5084 DType.INT8,
5085 DType.INT16,
5086 DType.INT32,
5087 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005088 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01005089 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01005090 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005091 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005092 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5093 outputDType = rng.choice(wrong_dtypes)
5094 else:
5095 outputDType = a.dtype
5096
5097 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005098
5099 @staticmethod
Won Jeona21b2e82023-08-10 10:33:01 +00005100 def dimOp(ser, rng, a, axis, error_name=None):
Tai Ly79359722023-08-31 18:17:38 +00005101 output_shape = []
Won Jeona21b2e82023-08-10 10:33:01 +00005102
5103 if error_name == ErrorIf.WrongOutputType:
5104 all_dtypes = [
5105 DType.INT8,
5106 DType.INT16,
5107 DType.INT32,
5108 DType.INT48,
5109 DType.FP32,
5110 DType.FP16,
5111 DType.BF16,
5112 ]
5113 wrong_dtypes = list(set(all_dtypes))
5114 outputDType = rng.choice(wrong_dtypes)
5115 else:
5116 outputDType = DType.SHAPE
5117
5118 return ser.addOutput(output_shape, outputDType)
5119
5120 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005121 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005122 output_shape = shape.copy()
5123
Matthew Haddone807aae2021-10-11 18:12:58 +01005124 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5125 for i in range(len(output_shape)):
5126 output_shape[i] = output_shape[i] + rng.integers(1, 10)
5127
5128 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005129 all_dtypes = [
5130 DType.INT8,
5131 DType.INT16,
5132 DType.INT32,
5133 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005134 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005135 DType.FP16,
5136 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005137 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005138 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5139 outputDType = rng.choice(wrong_dtypes)
5140 else:
5141 outputDType = a.dtype
5142
5143 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005144
5145 @staticmethod
Luke Huttona4e48ca2023-02-22 11:53:48 +00005146 def sliceOp(ser, rng, input, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005147
Matthew Haddone807aae2021-10-11 18:12:58 +01005148 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005149 all_dtypes = [
5150 DType.INT8,
5151 DType.INT16,
5152 DType.INT32,
5153 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005154 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005155 DType.FP16,
5156 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005157 ]
Luke Huttona4e48ca2023-02-22 11:53:48 +00005158 wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
Matthew Haddone807aae2021-10-11 18:12:58 +01005159 outputDType = rng.choice(wrong_dtypes)
5160 else:
Luke Huttona4e48ca2023-02-22 11:53:48 +00005161 outputDType = input.dtype
Matthew Haddone807aae2021-10-11 18:12:58 +01005162
Luke Huttona4e48ca2023-02-22 11:53:48 +00005163 output_shape = size.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005164 if error_name == ErrorIf.SizeOutputShapeMismatch:
Matthew Haddone807aae2021-10-11 18:12:58 +01005165 for index in range(len(output_shape)):
5166 if output_shape[index] <= 2:
5167 output_shape[index] = output_shape[index] + rng.choice([1, 2])
5168 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005169 output_shape[index] = output_shape[index] + rng.choice(
5170 [-2, -1, 1, 2]
5171 )
Luke Huttona4e48ca2023-02-22 11:53:48 +00005172 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
5173 output_shape = input.shape.copy()
5174 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005175 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Matthew Haddone807aae2021-10-11 18:12:58 +01005176
5177 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005178
5179 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005180 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005181
5182 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005183 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005184
5185 for i in range(len(output_shape)):
5186 output_shape[i] = a.shape[i] * multiples[i]
5187
Luke Huttona4e48ca2023-02-22 11:53:48 +00005188 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005189 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005190
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005191 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005192 all_dtypes = [
5193 DType.INT8,
5194 DType.INT16,
5195 DType.INT32,
5196 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005197 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005198 DType.FP16,
5199 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005200 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005201 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5202 outputDType = rng.choice(wrong_dtypes)
5203 else:
5204 outputDType = a.dtype
5205
5206 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005207
5208 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005209 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005210 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005211
Kevin Cheng550ccc52021-03-03 11:21:43 -08005212 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005213
Luke Huttona4e48ca2023-02-22 11:53:48 +00005214 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01005215 for i in range(len(output_shape)):
5216 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005217
Luke Huttona4e48ca2023-02-22 11:53:48 +00005218 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5219 for i in range(len(output_shape)):
5220 output_shape[i] += rng.integers(1, 10)
5221 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005222 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005223
Matthew Haddone807aae2021-10-11 18:12:58 +01005224 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005225 all_dtypes = [
5226 DType.INT8,
5227 DType.INT16,
5228 DType.INT32,
5229 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005230 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005231 DType.FP16,
5232 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005233 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005234 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5235 outputDType = rng.choice(wrong_dtypes)
5236 else:
5237 outputDType = a.dtype
5238
5239 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005240
5241 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005242 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005243 if error_name != ErrorIf.WrongRank:
5244 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005245 assert len(indices.shape) == 2
5246 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005247
Kevin Cheng77d0f762020-11-24 10:26:32 -08005248 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5249
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005250 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005251 all_dtypes = [
5252 DType.INT8,
5253 DType.INT16,
5254 DType.INT32,
5255 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005256 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005257 DType.FP16,
5258 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005259 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005260 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
5261 outputDType = rng.choice(wrong_dtypes)
5262 else:
5263 outputDType = values.dtype
5264
5265 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005266
5267 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005268 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005269 if error_name != ErrorIf.WrongRank:
5270 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005271 assert len(indices.shape) == 2
5272 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005273 assert values_in.shape[0] == indices.shape[0] # N
5274 assert input.shape[1] == indices.shape[1] # W
5275 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005276
5277 output_shape = values_in.shape
5278
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005279 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005280 all_dtypes = [
5281 DType.INT8,
5282 DType.INT16,
5283 DType.INT32,
5284 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005285 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005286 DType.FP16,
5287 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005288 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005289 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
5290 outputDType = rng.choice(wrong_dtypes)
5291 else:
5292 outputDType = values_in.dtype
5293
5294 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005295
5296 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005297 def tableOp(ser, rng, input, error_name=None):
5298 # Same shape as the input, dtype dependent on input dtype
5299 if error_name != ErrorIf.WrongInputType:
5300 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00005301 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005302 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005303 wrong_dtypes = [
5304 DType.INT8,
5305 DType.INT16,
5306 DType.INT32,
5307 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005308 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005309 DType.FP16,
5310 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005311 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005312 wrong_dtypes.remove(output_dtype)
5313 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005314 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005315
5316 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005317 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005318 serializer,
5319 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005320 input,
5321 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005322 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005323 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005324 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005325 input_dtype,
5326 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005327 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005328 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005329 # Calculate OH, OW
5330 scale_y_n = scale[0]
5331 scale_y_d = scale[1]
5332 scale_x_n = scale[2]
5333 scale_x_d = scale[3]
5334 if error_name == ErrorIf.ScaleSmallerEqualZero:
5335 scale_y_n = max(scale_y_n, 1)
5336 scale_y_d = max(scale_y_d, 1)
5337 scale_x_n = max(scale_x_n, 1)
5338 scale_x_d = max(scale_x_d, 1)
5339
5340 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
5341 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
5342
5343 if error_name is not None:
5344 # Make sure the output tensor is valid, which can occur when
5345 # scale, offset or border have been changed for ERROR_IFs
5346 oh = max(oh, 1)
5347 ow = max(ow, 1)
5348 if error_name != ErrorIf.MaxDimExceeded:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005349 oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
5350 ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005351
5352 if error_name == ErrorIf.ResizeOutputShapeMismatch:
5353 choices = [1, 2, 3]
5354 change = rng.choice(choices)
5355 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
5356 if change in [1, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005357 if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005358 oh -= scale_y_d
5359 assert oh > 0 # Should have been caught in agResize
5360 else:
5361 oh += scale_y_d
5362 if change in [2, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005363 if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005364 ow -= scale_x_d
5365 assert ow > 0 # Should have been caught in agResize
5366 else:
5367 ow += scale_x_d
5368
Matthew Haddon848efb42021-09-09 12:30:53 +01005369 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005370 output_dims = [
5371 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005372 oh,
5373 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005374 input.shape[0],
5375 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005376 elif error_name == ErrorIf.BatchMismatch:
5377 output_dims = [
5378 input.shape[0] + rng.integers(1, 10),
5379 oh,
5380 ow,
5381 input.shape[3],
5382 ]
5383 elif error_name == ErrorIf.ChannelMismatch:
5384 output_dims = [
5385 input.shape[0],
5386 oh,
5387 ow,
5388 input.shape[3] + rng.integers(1, 10),
5389 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01005390 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005391 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005392
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005393 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005394
5395 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005396 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005397 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005398
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D, optionally injecting an error.

    Args:
        ser: serializer the output tensor is added to (via ``addOutput``).
        rng: random generator used to pick error perturbations.
        ifm: input tensor; only its ``dtype`` is inspected.
        output_shape: requested output shape. NOTE: for
            ``ErrorIf.ConvOutputShapeMismatch`` this list is modified IN PLACE
            (indices 1 and/or 2), so the caller sees the perturbed shape too.
        accum_dtype: accumulator dtype, used as the output dtype by default.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        # Selector: 1 -> bump index 1 only, 2 -> index 2 only, 3 -> both.
        selector = rng.choice(deltas)
        for dim, selected_by in ((1, (1, 3)), (2, (2, 3))):
            if selector in selected_by:
                output_shape[dim] += rng.choice(deltas)

    # With an invalid input type there is no correct output type, so use a
    # plausible placeholder instead of the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both are legal
        # accumulator/output types for FP16 -- TODO confirm against the spec).
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        candidates = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005424
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D, optionally injecting an error.

    Both outputs get the same shape and dtype. By default these are the
    inputs' shape and dtype. The two inputs must share a dtype, and must
    share a shape unless ``ErrorIf.FFTInputShapeMismatch`` is being injected.

    Args:
        serializer: serializer the outputs are added to (via ``addOutput``).
        rng: random generator used to pick error perturbations.
        ifm1, ifm2: the two input tensors (presumably the real and imaginary
            parts -- TODO confirm).
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list holding the two output tensors returned by ``addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    input_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    # Use list() instead of .copy() so tuple shapes are accepted as well as lists.
    # The error injection below mutates this copy, never ifm1.shape itself.
    output_shape = list(input_shape)
    output_dtype = input_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP32 is the only type excluded, i.e. treated as the valid output type.
        excludes = [DType.FP32]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Perturb dimension 1 or 2 (assumed [N, H, W] layout -- TODO confirm).
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    # Two outputs with identical shape/dtype, added in order.
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5455
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D, optionally injecting an error.

    The output shape equals the input shape except that the last dimension
    ``W`` becomes ``W // 2 + 1``. Both outputs share that shape and the
    input's dtype, unless an error case changes them.

    Args:
        serializer: serializer the outputs are added to (via ``addOutput``).
        rng: random generator used to pick error perturbations.
        value: the input tensor.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list holding the two output tensors returned by ``addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Keep every dimension but the last, which shrinks to W // 2 + 1.
    out_shape = list(in_shape[:-1])
    out_shape.append(in_shape[-1] // 2 + 1)

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bumped_dim = rng.choice([1, 2])
        out_shape[bumped_dim] += rng.integers(1, 10)

    results = []
    for _ in range(2):
        results.append(serializer.addOutput(out_shape, out_dtype))
    return results