blob: 16021094ed4800a5f2e05fbed75b439681ed9ca9 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of generated C++ source files (for example the
# meta data files produced by TosaTestGen.serialize). The copyright year is
# taken from the date on which the generator is run.
TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
// SPDX-License-Identifier: Apache-2.0
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by the test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6

# Further limits named after the 8K level of the specification
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8 * 1024
TOSA_8K_LEVEL_MAX_STRIDE = 8 * 1024

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise generator state from the parsed command line ``args``.

    Seeds the RNG from ``args.random_seed``, builds the operator lists,
    creates the schema validator and data generator library, and resolves
    the floating point value range for each FP dtype.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def _resolve_fp_range(range_values, fp_max):
        # Replace the "max"/"-max" placeholders from the command line with
        # the dtype's maximum value and return the bounds in ascending order
        placeholders = {"max": fp_max, "-max": -fp_max}
        return tuple(sorted(placeholders.get(v, v) for v in range_values))

    # Work out the floating point range for each FP dtype
    self.random_float_range = {
        fp_dtype: _resolve_fp_range(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype],
        )
        for fp_dtype in (DType.FP32, DType.FP16, DType.BF16)
    }
78
def createSerializer(self, opName, testPath):
    """Create a fresh serializer for the test ``<opName>/<testPath>``.

    The test directory under ``basePath`` is created if needed. The const
    mode depends on the options: lazy data generation writes constants as
    input files, ``dump_consts`` embeds and dumps them, otherwise they are
    embedded in the flatbuffer.
    """
    self.testPath = os.path.join(opName, testPath)
    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070092
def getSerializer(self):
    """Return the serializer made by the latest createSerializer() call (None before that)."""
    serializer = self.ser
    return serializer
95
def serialize(self, testName, metaData=None):
    """Write the current test to ``<basePath>/<testPath>``.

    Always writes ``<testName>.tosa`` (flatbuffer) and ``desc.json`` (the
    serializer's JSON descriptor, with ``metaData`` stored under "meta" when
    given, validated against the schema first). When ``metaData`` holds
    "data_gen" and lazy data generation is enabled, that entry is also
    written as C++ source. Any "compliance" entry is always written as C++
    source.
    """
    out_dir = Path(self.basePath) / self.testPath

    # Write out TOSA flatbuffer binary
    with (out_dir / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer, adding any extra meta data
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before it (or the meta data files) is written
    self.descSchemaValidator.validate_config(desc)

    def write_cpp_meta(file_name, description, declaration, payload):
        # Emit payload as JSON inside a C++ raw string literal
        with (out_dir / file_name).open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(description)
            fd.write(declaration)
            json.dump(payload, fd)
            fd.write(')";\n\n')

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Output datagen meta data as CPP data
            write_cpp_meta(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{out_dir.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            write_cpp_meta(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{out_dir.stem} = R"(',
                metaData["compliance"],
            )

    # Write desc.json
    with (out_dir / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700139
def resetRNG(self, seed=None):
    """Re-create ``self.rng``.

    Uses ``seed`` when given; otherwise uses one more than the configured
    ``random_seed``.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
144
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the ``(low, high)`` value range used for random ``dtype`` data.

    FP32/FP16/BF16 return the configured float range unchanged (whatever
    ``high_inclusive`` is). For the remaining types ``high`` is exclusive,
    unless ``high_inclusive`` is true, in which case ``high - 1`` is returned.

    Raises:
        Exception: if ``dtype`` is not recognised.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return self.random_float_range[dtype]

    # (matching dtypes, (low, exclusive high))
    exclusive_bounds = (
        ((DType.BOOL,), (0, 2)),
        ((DType.UINT8,), (0, 256)),
        ((DType.UINT16,), (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        ((DType.INT4,), (-7, 8)),
        ((DType.INT8,), (-128, 128)),
        ((DType.INT16,), (-32768, 32768)),
        # SHAPE values are restricted to the INT32 range
        ((DType.INT32, DType.SHAPE), (-(1 << 31), (1 << 31))),
        ((DType.INT48,), (-(1 << 47), (1 << 47))),
    )
    for members, (low, high) in exclusive_bounds:
        if dtype in members:
            break
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    # Exclusive: low <= value < high; inclusive: low <= value <= high
    return (low, high - 1) if high_inclusive else (low, high)
178
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy tensor of ``shape`` for TOSA ``dtype``.

    Values come from ``data_range`` (low, high) when given, else from
    getDTypeRange(dtype). BOOL gives numpy bools, INT48 int64, FP16 float16,
    FP32/BF16 float32 (BF16 values reduced via gtu.vect_f32_to_bf16), and
    every other integer type int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)
    samples = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(samples))
    return samples
Eric Kunzee5e26762020-10-13 16:11:07 -0700204
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Random data is generated per placeholder unless ``args.lazy_data_gen``
    is set, in which case None is passed as the data.

    Returns:
        list of the values returned by ``self.ser.addPlaceholder``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
217
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Random data is generated per constant unless ``args.lazy_data_gen`` is
    set, in which case None is passed as the data.

    Returns:
        list of the values returned by ``self.ser.addConst``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
230
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    A non-empty target shape set with setTargetShape() is returned as-is, and
    ``rank`` is ignored. Otherwise ``rank`` dimensions are drawn from
    ``args.tensor_shape_range`` (low inclusive, high exclusive).
    """
    target = self.targetted_shape
    # Test by length rather than truthiness: a numpy array target would make
    # `if target:` raise "truth value of an array is ambiguous". An empty
    # target still falls back to a random shape, as before.
    if target is not None and len(target) > 0:
        return np.int32(target)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700241
def setTargetShape(self, shape):
    """Store ``shape`` as the target shape used by makeShape()."""
    setattr(self, "targetted_shape", shape)
244
def randInt(self, low=0, high=256):
    """Return one random np.int32 in ``[low, high)`` drawn from ``self.rng``."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    as_int32 = np.int32(drawn)
    return as_int32[0]
247
def getRandNumberDType(self, dtype):
    """Return one random scalar of TOSA ``dtype`` within getDTypeRange(dtype).

    FP32 gives np.float32, FP16 np.float16, BF16 the result of
    gtu.vect_f32_to_bf16 on a float32 sample, BOOL a random bool, INT48
    np.int64, and every other type np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample = np.float32(sample)
        return sample if dtype == DType.FP32 else gtu.vect_f32_to_bf16(sample)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # INT48 needs the wider (special size) integer container
    container = np.int64 if dtype == DType.INT48 else np.int32
    return container(self.rng.integers(low, high, size=1))[0]
265
def shapeStr(self, shape):
    """Render ``shape`` as its dimensions joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700274
def typeStr(self, dtype):
    """Return the short string name of ``dtype`` from gtu.DTYPE_ATTRIBUTES.

    For a list/tuple of dtypes (at least two), every entry is converted, but
    only the first two names are joined with "x".

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700288
def typeWidth(self, dtype):
    """Return the bit width of ``dtype`` as recorded in gtu.DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700295
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` (the batch size) to ``args.max_batch_size``.

    ``shape`` is modified in place and returned. Nothing changes when no
    maximum batch size is set or when explicit target shapes were requested.
    """
    limit = self.args.max_batch_size
    if not limit or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], limit)
    return shape
301
def makeDimension(self):
    """Return one random dimension size from ``args.tensor_shape_range`` via randInt()."""
    shape_range = self.args.tensor_shape_range
    lowest, highest = shape_range[0], shape_range[1]
    return self.randInt(low=lowest, high=highest)
306
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data for an op's expected output tensor.

    Returns None for ERROR_IF tests, for output types compliance does not
    support, and for MATMUL/CONV2D/FULLY_CONNECTED with unsupported input
    types. Otherwise it returns a dict with "mode" (a gtu.ComplianceMode
    name) and "data_type", plus "dot_product_info" or "ulp_info" when that
    mode needs it.
    """
    # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Data type is needed for all FP runs, as refmodel precise mode produces FP64
    data_type = gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"]

    mode_info = {}
    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        mode_info["dot_product_info"] = {
            "s": argsDict["s"],
            # "ksb" takes precedence over "ks" when present
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        mode_info["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.EXP, Op.POW):
        mode = gtu.ComplianceMode.ABS_ERROR
    else:
        mode = gtu.ComplianceMode.EXACT

    # Key order matches the original: mode, data_type, then mode-specific info
    return {
        "mode": gtu.ComplianceMode(mode).name,
        "data_type": data_type,
        **mode_info,
    }
351
352 # Build Op functions
353 # Create the output tensor (calling OutputShaper as needed)
354 # Do final tweaks to attributes (if necessary for errorIf)
355 # Add Op into graph
356 # Return resulting tensor information or BuildInfo
357
class BuildInfo:
    """Result of a build function: the output tensor plus its compliance meta data."""

    def __init__(self, resultTensor, complianceDict):
        # complianceDict may be None when no compliance data applies
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700364
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a single-input operator to the graph.

    Args:
        op: operator definition dict ("op" is the TOSA Op, "operands" the
            placeholder/const operand counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF being tested, or None.
        qinfo: zero points; passed to ERROR_IF validation and, for NEGATE,
            used as (input, output) zero points of the attribute.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    wrong_output_type = error_name == ErrorIf.WrongOutputType
    if wrong_output_type and result_tensor.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700417
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input broadcasting operator to the graph.

    Args:
        op: operator definition dict ("op" is the TOSA Op, "operands" the
            placeholder/const operand counts).
        inputs: list holding exactly two input tensors.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validator functions. Defaults to None like
            the sibling build_unary/build_mul/build_comparison functions.
        error_name: ERROR_IF being tested, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700459
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input, non-broadcasting operator ``op["op"]`` to the graph.

    No ERROR_IF handling is done here: ``validator_fcns`` and ``error_name``
    are accepted but unused. Returns the result tensor.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operation = op["op"]
    self.ser.addOperator(operation, [a.name, b.name], [output.name])
    return output
464
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an ARITHMETIC_RIGHT_SHIFT style operator with inputs ``a`` and ``b``.

    ``round`` is passed to the ArithmeticRightShiftAttribute.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_names, output_names, shift_attr)
    return result_tens
502
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL style operator with two broadcast inputs and a shift attribute.

    Integer inputs produce an INT32 result. When testing WrongOutputType, a
    random INT8/INT16/INT48 output type is used instead. The shift comes from
    ``args_dict["shift"]``.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(self.rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_names, output_names, mul_attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700558
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a TABLE style operator on input ``a`` using lookup values ``table``.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names, table_attr)
    return result_tens
592
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a SELECT style operator choosing between ``a`` and ``b`` by ``cond``.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [cond.name, a.name, b.name]
    output_names = [result_tens.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result_tens
629
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a two-input comparison operator to the graph.

    Args:
        op: operator definition dict ("op" is the TOSA Op, "operands" the
            placeholder/const operand counts).
        inputs: list holding exactly two input tensors.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF being tested, or None.
        qinfo: accepted but not used by this builder.

    Returns:
        TosaTestGen.BuildInfo, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700677
def build_argmax(
    self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build an ARGMAX operator along args_dict["axis"].

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, self.rng, src, axis, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "stride", "pad", "kernel" and the optional
            "acc_type" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points for PoolAttribute, or None to use zeros.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accumulator type, so "acc_type" may be absent.
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # A deliberately wrong input type needs zero points matching the
    # actual input/output types.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    zp_first, zp_second = (0, 0) if qinfo is None else (qinfo[0], qinfo[1])

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, zp_first, zp_second, accum_dtype)
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100795
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence of exactly three tensors, unpacked as
            (input, filter, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" (four values)
            and "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to ConvAttribute, or None to use zeros.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700877
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence of exactly three tensors, unpacked as
            (input, filter, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" (six values)
            and "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to ConvAttribute, or None to use zeros.

    Returns:
        The result tensor (no compliance metadata is produced here), or None
        when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
954
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        ifm: Input tensor.
        filter: Filter tensor. The name is kept for keyword callers even though
            it shadows the builtin.
        bias: Bias tensor.
        accum_dtype: Accumulator type passed to the output shaper.
        stride: Stride values.
        out_pad: Output padding (four values).
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to TransposeConvAttribute, or None to
            use zeros.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1022
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence of exactly three tensors, unpacked as
            (input, filter, bias).
        args_dict: Test arguments; "acc_type", "stride", "pad" and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to ConvAttribute, or None to use zeros.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1098
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence of exactly three tensors, unpacked as
            (input, filter, bias).
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to FullyConnectedAttribute, or None to
            use zeros.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001154
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator from two input tensors.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence of exactly two input tensors.
        args_dict: Test arguments; "acc_type" is read and the dict is
            forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero points passed to MatMulAttribute, or None to use zeros.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None but is indexed below. Fall back to zero points
    # of 0 (as build_pool2d does) instead of raising TypeError.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001204
def build_reduce(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a reduction operator along args_dict["axis"].

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments; "axis" is read and the dict is forwarded
            to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance metadata
        (None for Op.REDUCE_PRODUCT), or None when ERROR_IF validation
        rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, src, axis, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    # TODO: Add compliance support for REDUCE_PRODUCT!
    compliance = (
        None
        if op["op"] == Op.REDUCE_PRODUCT
        else self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tensor, error_name
        )
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001253
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly drawn min/max bounds.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
            For ErrorIf.MaxSmallerMin the two drawn bounds are forced to
            differ and are then swapped.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def draw_pair():
        return [
            self.getRandNumberDType(src.dtype),
            self.getRandNumberDType(src.dtype),
        ]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Redraw until the values differ, so swapping them really gives max < min.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    if src.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        # Floating-point bounds go in the last two slots; integer slots are zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001319
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator whose attribute value is a random FP32 number.

    Args:
        op: Operator description dict; its "op" entry is read.
        a: Input tensor.
        validator_fcns: Not used by this builder.
        error_name: ERROR_IF case, passed on to the output shaper.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    alpha = self.getRandNumberDType(DType.FP32)

    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1328
1329 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001330 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1331 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001332
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001333 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001334 return result_tens
1335
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input operator that takes no attribute.

    Args:
        op: Operator description dict; its "op" and "operands" entries are read.
        inputs: Sequence holding exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately invalidated.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, out_tensor, error_name
        ),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001376
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a concatenation operator joining ``inputs`` along an axis.

    Args:
        op: operator description dict; ``op["op"]`` is serialized as the opcode
            and ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: list of input tensors to concatenate.
        args_dict: must contain ``"axis"``, stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the output tensor and no compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    axis = args_dict["axis"]
    # The axis type check is skipped for the WrongInputType ERROR_IF case
    # (presumably because the axis may legitimately be malformed there).
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001424
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a pad operator for the single input tensor.

    Args:
        op: operator description dict; ``op["op"]`` is serialized as the opcode
            and ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: list holding exactly one input tensor.
        args_dict: must contain ``"pad"`` (an array that is flattened into the
            attribute), ``"pad_const_int"`` and ``"pad_const_fp"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None.
        qinfo: forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo`` (output tensor + compliance metadata), or
        None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    pad_values = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, src, pad_values, error_name)

    # The attribute uses the serializer's builder, so it is created here, at
    # the same point as before, ahead of validation.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_values.flatten(), const_int, const_fp)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=pad_values,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance_info = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07001482
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a dim operator on the single input, parameterized by an axis.

    Args:
        op: operator description dict; ``op["op"]`` is serialized as the opcode
            and ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: list holding exactly one input tensor.
        args_dict: must contain ``"axis"``, stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the output tensor and no compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001528
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a reshape operator for tensor ``a``.

    ``newShape`` is passed to ``OutputShaper.reshapeOp`` to create the output
    tensor and is stored in the operator's ``ReshapeAttribute``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], input_list, output_list, reshape_attr)
    return out_tensor
1564
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a reverse operator on the single input along ``args_dict["axis"]``.

    The output tensor is shaped by ``OutputShaper.unaryOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the output tensor and no compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001604
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a transpose operator for tensor ``a`` using permutation ``perms``.

    ``perms`` drives the output shape (via ``OutputShaper.transposeOp``), is
    checked by the ERROR_IF validators, and is stored in the attribute.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, perm_attr)
    return out_tensor
1640
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a slice operator taking ``size`` elements of ``a`` from ``start``.

    ``start`` and ``size`` are used for the output shape
    (``OutputShaper.sliceOp``), for ERROR_IF validation, and for the
    ``SliceAttribute``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], input_list, output_list, slice_attr)
    return out_tensor
1679
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a tile operator replicating ``a`` according to ``multiples``.

    ``multiples`` is used for the output shape (``OutputShaper.tileOp``) and
    stored in the ``TileAttribute``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
        input1=a,
    )
    if not ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], input_list, output_list, tile_attr)
    return out_tensor
1714
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a gather operator reading from ``values`` with generated indices.

    A new constant INT32 indices tensor of shape ``(values.shape[0], W)`` is
    created. Every entry lies in ``[0, values.shape[1])``, so no index exceeds
    the gathered dimension. W is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    k_extent = values.shape[1]
    w_extent = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    batch = values.shape[0]
    # Generator.integers excludes ``high``, keeping indices below k_extent.
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[batch, w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return out_tensor
1761
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a scatter operator writing ``input`` into ``values_in``.

    A new constant INT32 indices tensor of shape
    ``(values_in.shape[0], input.shape[1])`` is created. Every entry lies in
    ``[0, values_in.shape[1])``, so no index exceeds the target dimension.
    The parameter name ``input`` shadows the builtin; it is kept for caller
    compatibility.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    updates = input
    k_extent = values_in.shape[1]
    w_extent = updates.shape[1]
    batch = values_in.shape[0]
    # Generator.integers excludes ``high``, keeping indices below k_extent.
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[batch, w_extent])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, updates, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, updates.name],
        [out_tensor.name],
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return out_tensor
1806
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a resize operator for tensor ``input``.

    Args:
        op: operator description dict; ``op["op"]`` is serialized as the opcode
            and ``op["operands"]`` is unpacked into two counts that are summed.
        input: input tensor (name shadows the builtin; kept for callers).
        mode, scale, offset, border: resize parameters, passed to
            ``OutputShaper.resizeOp``, the ERROR_IF validators and
            ``ResizeAttribute``.
        input_dtype, output_dtype: dtypes handed to the shaper and validators
            (used instead of the tensors' own dtypes).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    src = input
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        src,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[out_tensor],
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return out_tensor
1868
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an identity-N operator with two inputs and two outputs.

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from its
    corresponding input. No ERROR_IF validation is done; ``validator_fcns``
    is unused.

    Args:
        op: operator description dict whose ``"op"`` entry is the opcode. A
            bare opcode is also accepted for backward compatibility.
        val, val2: the two input tensors.

    Returns:
        The first output tensor only; the second is serialized but not
        returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Serialize the opcode, as every other builder does (op["op"]). The
    # original code passed the whole description dict here.
    opcode = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1876
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the graph and return it unchanged.

    No operator is serialized. ``op``, ``validator_fcns`` and ``error_name``
    are not used here.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001880
1881 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a cast operator converting the single input to ``args_dict["out_type"]``.

    The output tensor is created by ``OutputShaper.typeConversionOp``.

    Args:
        op: operator description dict; ``op["op"]`` is serialized as the opcode
            and ``op["operands"]`` is unpacked into two counts that are summed.
        inputs: list holding exactly one input tensor.
        args_dict: must contain ``"out_type"``; also forwarded to
            ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` (output tensor + compliance metadata), or
        None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    out_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately invalidated.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance_info = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07001925
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001926 def build_rescale(
1927 self,
1928 op,
1929 val,
1930 out_dtype,
1931 scale32,
1932 double_round,
1933 per_channel,
1934 validator_fcns,
1935 error_name,
1936 ):
1937 result_tens = OutputShaper.typeConversionOp(
1938 self.ser, self.rng, val, out_dtype, error_name
1939 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001940
1941 if per_channel:
1942 nc = val.shape[-1]
1943 else:
1944 nc = 1
1945
1946 in_type_width = self.typeWidth(val.dtype)
1947 out_type_width = self.typeWidth(out_dtype)
1948
Kevin Cheng3a478572021-01-22 17:21:02 -08001949 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001950 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001951 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001952 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001953 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001954 in_type_width += 1
1955 elif error_name in [
1956 ErrorIf.InputZeroPointNotZero,
1957 ErrorIf.U16InputZeroPointNotValid,
1958 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001959 input_zp = self.randInt(-128, 128)
1960 if input_zp == 0:
1961 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001962 in_type_width += 1
1963 elif val.dtype == DType.UINT16:
1964 # Must come after ErrorIf.U16InputZeroPointNotValid check
1965 input_zp = self.rng.choice([0, 32768])
1966 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001967 else:
1968 input_zp = 0
1969
Kevin Cheng3a478572021-01-22 17:21:02 -08001970 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001971 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001972 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001973 elif out_dtype == DType.UINT8:
1974 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001975 out_type_width += 1
1976 elif error_name in [
1977 ErrorIf.OutputZeroPointNotZero,
1978 ErrorIf.U16OutputZeroPointNotValid,
1979 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001980 output_zp = self.randInt(-128, 128)
1981 if output_zp == 0:
1982 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001983 out_type_width += 1
1984 elif out_dtype == DType.UINT16:
1985 # Must come after ErrorIf.U16OutputZeroPointNotValid check
1986 output_zp = self.rng.choice([0, 32768])
1987 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001988 else:
1989 output_zp = 0
1990
1991 # Calculate scale based on:
1992 # scale = a *(2^output_width)/(2^input_width))
1993
1994 a = np.float32(self.rng.random(size=[nc]))
1995 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1996
1997 if scale32:
1998 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01001999 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07002000 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
2001 else:
2002 # Cap the scaling at 2^15 - 1 for scale16
2003 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
2004
Kevin Cheng550ccc52021-03-03 11:21:43 -08002005 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07002006
2007 multiplier_arr = np.int32(np.zeros(shape=[nc]))
2008 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002009 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2010 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002011
2012 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002013 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
2014 scale_arr[i], scale32
2015 )
Eric Kunze750d27d2022-06-30 21:37:09 +00002016 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2017 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002018
Kevin Cheng550ccc52021-03-03 11:21:43 -08002019 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002020 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002021 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002022 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002023 assert val.placeholderFilename
2024 values = np.load(
2025 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2026 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002027 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2028 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2029 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
2030 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002031 if not np.all(np.array_equal(values, val_adj)):
2032 # Values changed so overwrite file with new values
2033 np.save(
2034 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2035 val_adj,
2036 False,
2037 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002038
Matthew Haddonc2025212021-10-08 21:21:05 +01002039 # Invalidate Input/Output list for error if checks.
2040 input_list = [val.name]
2041 output_list = [result_tens.name]
2042 pCount, cCount = op["operands"]
2043 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002044 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2045 self, error_name, input_list, output_list
2046 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002047
2048 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002049 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002050 self.ser,
2051 validator_fcns,
2052 error_name,
2053 op=op,
2054 input_dtype=val.dtype,
2055 output_dtype=out_dtype,
2056 input_shape=val.shape,
2057 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002058 scale32=scale32,
2059 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002060 input_list=input_list,
2061 output_list=output_list,
Luke Hutton261b7b62023-01-10 14:50:31 +00002062 result_tensors=[result_tens],
Matthew Haddonc2025212021-10-08 21:21:05 +01002063 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002064 ):
2065 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002066
Eric Kunzee5e26762020-10-13 16:11:07 -07002067 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002068 attr.RescaleAttribute(
2069 input_zp,
2070 output_zp,
2071 multiplier_arr,
2072 shift_arr,
2073 scale32,
2074 double_round,
2075 per_channel,
2076 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002077
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002078 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002079 return result_tens
2080
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. For ERROR_IF
    tests it is deliberately corrupted:
    - CondIfCondNotMatchingBool: the dtype comes from
      ``gtu.get_wrong_output_type`` instead of BOOL.
    - CondIfCondShapeNotSizeOne: the shape is randomly [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    # A valid condition has exactly one element (rank 0)
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2097
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored and
    ``cond`` is the value stored in the condition constant. For the
    CondIfOutputListThenGraphMismatch / CondIfOutputListElseGraphMismatch
    ERROR_IF tests, the matching block outputs a constant whose shape
    differs from the COND_IF result shape.

    Returns:
        The COND_IF result tensor, or None when evValidateErrorIfs rejects
        the test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Perturb every dimension. Dims <= 3 only grow, so the result stays
        # positive for any non-negative input dim.
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result takes the shape of the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Emit each block with its constant output, using the wrong shape for
    # the block targeted by the current ERROR_IF (if any)
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if valid else None
2166
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks combine ``a`` and ``b``.

    The THEN and ELSE blocks each apply a different binary operator, chosen
    from ``a.dtype``:
    - ADD / SUB for FP32, BF16, FP16 and INT32.
    - LOGICAL_RIGHT_SHIFT / LOGICAL_LEFT_SHIFT for INT8 and INT16.

    For the CondIfInputList*GraphMismatch ERROR_IF tests, the targeted block
    declares a first input with a wrong shape. For the
    CondIfOutputList*GraphMismatch tests, it declares an output with a wrong
    shape.

    Args:
        op: operator description dict. ``op["op"]`` is the COND_IF opcode,
            and the dict itself is forwarded to the ERROR_IF validators.
        a, b: operand tensors passed to both blocks.
        cond: value stored in the condition constant.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF under test, or None.

    Returns:
        The COND_IF result tensor, or None when evValidateErrorIfs rejects
        the test.

    Raises:
        AssertionError: if ``a.dtype`` has no THEN/ELSE operator pair.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        # deepcopy so that changing the copy's shape leaves ``a`` untouched
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise: a bare ``assert`` is stripped under -O, which would
        # leave then_op/else_op unbound below
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # NOTE: the loop variable must not be called ``op``. Rebinding it would
    # overwrite the operator description dict that is passed to the ERROR_IF
    # validators after the loop.
    for block, block_op, input_error, output_error in (
        (
            then_block,
            then_op,
            ErrorIf.CondIfInputListThenGraphMismatch,
            ErrorIf.CondIfOutputListThenGraphMismatch,
        ),
        (
            else_block,
            else_op,
            ErrorIf.CondIfInputListElseGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
        ),
    ):
        self.ser.addBasicBlock(block)
        # Only the block targeted by the current ERROR_IF gets a bad shape
        first_input = incorrect_block_input if error_name == input_error else a
        out_shape = (
            incorrect_block_input.shape if error_name == output_error else a.shape
        )
        self.ser.addInputTensor(first_input)
        self.ser.addInputTensor(b)
        tens = self.ser.addOutput(out_shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2249
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test with a counting COND block and an ADD body.

    Loop-carried values are (iter, a, acc):
    - iter starts as the rank-0 INT32 placeholder ``iter_val``.
    - acc starts as zeros with the shape of ``a``.
    The COND block computes ``iter > 0``. The BODY block computes
    ``acc + a`` and ``iter - 1`` and passes ``a`` through unchanged.
    Presumably acc_out therefore ends up as ``iter_val * a`` for
    non-negative iter_val (not verified here).

    The ERROR_IF variants handled here corrupt shapes or dtypes:
    - InputListOutputListMismatch: the shape of the loop's acc output.
    - InputListCondGraphMismatch: the COND block's declared inputs.
    - InputListBodyGraphInputMismatch: the BODY block's declared inputs.
    - InputListBodyGraphOutputMismatch: the BODY block's outputs.
    - CondGraphOutputNotMatchingBool: the COND output dtype.
    - CondGraphOutputShapeNotSizeOne: the COND output shape.

    Returns:
        The WHILE_LOOP acc output tensor, or None when evValidateErrorIfs
        rejects the test.
    """
    iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    # acc = self.ser.addOutput(a.shape, a.dtype)
    # NOTE(review): the initial value is always int32 zeros, whatever
    # a.dtype is. Confirm this is intended for non-INT32 inputs.
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # deepcopy so that perturbing the copy's shape leaves acc unchanged
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        # iter is created with shape [] above, so the loop does nothing.
        # Append a dimension so the shape really differs. The appended value
        # may be negative.
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    # The COND output must be a single-element BOOL; corrupt it for the
    # matching ERROR_IF tests
    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, acc + a)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2370
def build_fft2d(
    self,
    op,
    val1,
    val2,
    inverse,
    validator_fcns=None,
    error_name=None,
):
    """Add an FFT2D operator that takes ``val1`` and ``val2`` as inputs.

    Result tensors are created by ``OutputShaper.fft2dOp``. The input and
    output name lists go through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``, which may alter them
    for ERROR_IF tests.

    Returns:
        The list of result tensors, or None when evValidateErrorIfs rejects
        the test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], [tens.name for tens in results]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound; until then it is always serialized as False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2421
def build_rfft2d(
    self,
    op,
    val,
    validator_fcns=None,
    error_name=None,
):
    """Add an RFFT2D operator with ``val`` as its single input.

    Result tensors are created by ``OutputShaper.rfft2dOp``. The input and
    output name lists go through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``, which may alter them
    for ERROR_IF tests.

    Returns:
        The list of result tensors, or None when evValidateErrorIfs rejects
        the test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    out_shapes = [tens.shape for tens in results]
    out_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tens.name for tens in results]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    # TODO - Test local_bound; until then it is always serialized as False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2467
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Decide which shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator description dict. ``op["rank"]`` is the inclusive
            (min, max) rank range and ``op["types"]`` the supported dtypes.
            A dtype entry may be a list, in which case its first element is
            matched against ``dtypeFilter``.
        shapeFilter: requested shapes. An empty value or None means no
            shape request.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None for every operator dtype.
        testType: "positive" or "negative".
        validator: ERROR_IF validator used for negative tests. Its
            ``param_reqs`` override the rank/dtype/shape filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys.
        Returns None for a negative test without a validator, or for any
        other testType.
    """
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is None:
        # Nothing requested: use the operator's rank range, but cap it at the
        # default 1-4 inclusive range when that is smaller, to keep test
        # counts reasonable.
        default_ranks = range(1, 5)
        op_ranks = range(rank_lo, rank_hi + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(rank_lo, rank_hi + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]

    # Validator requirements win; otherwise fall back to the cleaned filters
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    # Only the first two shapes are kept, to keep negative test counts small
    neg_shapes = shapeFilter[:2] if reqs["shape"] is None else reqs["shape"]

    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2548
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for operator ``opName``.

    Each test is a tuple
    ``(opName, testStr, dtype, error_name, shapeList, args)``, where
    ``error_name`` is None for positive tests.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of shapes to restrict generation to.
            None (the default) means no shape restriction.
        rankFilter: optional list of ranks, or None.
        dtypeFilter: optional list of dtypes, or None.
        testType: "positive" or "negative" (ERROR_IF) tests.

    Returns:
        The list of test tuples. For positive tests, cases rejected by the
        operator's "invalid_test_validators" are removed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # None default instead of a shared mutable ``[None]`` list
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Re-seed so each operator's test list is generated from the same seed
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        if testType == "negative":
            name_prefix = f"{opName}_ERRORIF_{error_name}"
        else:
            name_prefix = opName

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Skip requested shapes that do not match this rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each entry is (argStr, build-function argument list)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{name_prefix}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr += f"_{argStr}"
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2655
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the operator graph for one test case and serialize it.

    Args:
        opName: Key of the operator entry in ``self.TOSA_OP_LIST``.
        testStr: Test name; passed to ``createSerializer`` and used in log
            messages.
        dtype_or_dtypeList: Either a list of per-operand dtypes, or a single
            dtype that is repeated once per operand (once per entry of
            ``shapeList`` for CONCAT).
        error_name: ERROR_IF condition name, forwarded to the quantization
            generator, the tensor values generator and the build function.
        shapeList: Operand shapes; its length must equal the operand count
            (except for CONCAT).
        testArgs: Either a dict (new interface, must contain ``"dg_type"``),
            or a sequence of extra positional build arguments (old interface).

    If the build function returns a truthy result the test is serialized
    (with compliance meta data when the result is a ``BuildInfo`` carrying a
    ``complianceDict``); otherwise an "Invalid ERROR_IF test" message is
    printed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the build function after printing the build
            function, tensors and arguments that were passed to it.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        # The KeyError context adds nothing to this message, so drop it
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT has one input per shape rather than a fixed operand count
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Build the random tensor operands and the test
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Work out the build function arguments for whichever interface is in
    # use, then make one call so both share the same TypeError diagnostics
    if isinstance(testArgs, dict):
        # New interface with args info in dictionary
        argsDict = testArgs
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
        build_args = (tens, argsDict)
        build_kwargs = {
            "validator_fcns": error_if_validators,
            "error_name": error_name,
            "qinfo": qinfo,
        }
    else:
        # Old interface with args info in a list: the tensors and the args
        # are expanded into positional arguments
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
        build_args = (*tens, *testArgs)
        build_kwargs = {}
        if error_if_validators is None:
            # Without validators, qinfo is passed positionally after the args
            if qinfo is not None:
                build_args += (qinfo,)
        else:
            build_kwargs["validator_fcns"] = error_if_validators
            build_kwargs["error_name"] = error_name
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo

    try:
        result = build_fcn(self, op, *build_args, **build_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002781
def createDynamicOpLists(self):
    """Expand the templated convolution ops into concrete per-kernel ops.

    For every kernel size, a copy of the matching ``<name>_TEMPLATE`` entry of
    ``self.TOSA_OP_LIST`` is added under ``<name>_<k0>x<k1>[x<k2>]``, with its
    ``"filter"`` set to that kernel and ``"template"`` set to False.
    Afterwards, every entry whose ``"template"`` field is truthy is removed.

    The kernel sizes depend on ``self.args.level8k``. In 8K-level mode they are
    built around ``self.TOSA_8K_LEVEL_MAX_KERNEL``.

    Returns immediately if ``"conv2d_TEMPLATE"`` is no longer present, so
    calling it again (e.g. when the class is initialized more than once) is a
    no-op.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # The templates have already been expanded and removed
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def _instantiate(base_name, kernel):
        # Shallow copy of the template with only "filter"/"template" replaced
        concrete_name = base_name + "_" + "x".join(str(dim) for dim in kernel)
        op_list[concrete_name] = {
            **op_list[base_name + "_TEMPLATE"],
            "filter": kernel,
            "template": False,
        }

    for kernel in kernels_2d:
        for base_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            _instantiate(base_name, kernel)

    for kernel in kernels_3d:
        _instantiate("conv3d", kernel)

    # Collect the template names first: deleting keys while iterating a dict
    # is not allowed
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2837
def initOpListDefaults(self):
    """Validate each TOSA_OP_LIST entry and fill in optional defaults.

    Every entry must provide an ``"operands"`` 2-tuple, a ``"build_fcn"``
    4-tuple, a ``"types"`` field and an ``"op"`` field. A missing or
    malformed field raises ``Exception`` naming the offending op. Entries
    without a ``"rank"`` receive ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Unpacking doubles as a structure check: the wrong arity raises
        # ValueError and a non-iterable raises TypeError
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tensor_gen, _values_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        try:
            _ = entry["types"]
        except KeyError:
            raise Exception(f"Op {name} is missing a valid type list in TOSA_OP_LIST")

        try:
            _ = entry["op"]
        except KeyError:
            raise Exception(f"Op {name} is missing the Op field in TOSA_OP_LIST")

        # Optional field: fall back to the default rank range when absent
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002879
# Tensor operator list (TOSA_OP_LIST); each entry is a dict with the fields:
# 'op': Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE
# 'build_fcn': 4-tuple of (build_operator() function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional TosaQuantGen function; its result is passed to the build
#     function as qinfo
# 'invalid_test_validators': optional TosaInvalidValidator functions; positive
#     tests for which any of them returns True are dropped
# 'error_if_validators': optional TosaErrorValidator functions, passed to the
#     build function as validator_fcns
# 'data_gen': optional data generator types, e.g. {"fp": (DataGenType, ...)}
#     (presumably keyed by dtype family - TODO confirm)
# 'compliance': optional compliance settings, e.g. {"ulp": 5}
# 'template': True for templated ops; createDynamicOpLists expands them into
#     one op per kernel size (setting 'filter') and then removes the template
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# floating-types, INT8/INT16/INT32 and BOOL
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 and floating-types (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used for ops whose entry has no 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002931
2932 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002933 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002934 "argmax": {
2935 "op": Op.ARGMAX,
2936 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002937 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002938 "build_fcn": (
2939 build_argmax,
2940 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002941 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002942 TosaArgGen.agAxis,
2943 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002944 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002945 "error_if_validators": (
2946 TosaErrorValidator.evAxisSmallerZero,
2947 TosaErrorValidator.evAxisLargerRank,
2948 TosaErrorValidator.evArgmaxOutputRankMismatch,
2949 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2950 TosaErrorValidator.evWrongRank,
2951 TosaErrorValidator.evWrongInputType,
2952 TosaErrorValidator.evWrongOutputType,
2953 TosaErrorValidator.evWrongInputList,
2954 TosaErrorValidator.evWrongOutputList,
2955 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002956 "data_gen": {
2957 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
2958 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002959 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002960 "avg_pool2d": {
2961 "op": Op.AVG_POOL2D,
2962 "operands": (1, 0),
2963 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002964 "build_fcn": (
2965 build_pool2d,
2966 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002967 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002968 TosaArgGen.agPooling,
2969 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 "qgen": TosaQuantGen.qgUnary,
2971 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002972 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002973 "error_if_validators": (
2974 TosaErrorValidator.evKernelSmallerOne,
2975 TosaErrorValidator.evStrideSmallerOne,
2976 TosaErrorValidator.evPadSmallerZero,
2977 TosaErrorValidator.evWrongRank,
2978 TosaErrorValidator.evWrongInputType,
2979 TosaErrorValidator.evWrongOutputType,
2980 TosaErrorValidator.evWrongInputList,
2981 TosaErrorValidator.evWrongOutputList,
2982 TosaErrorValidator.evInputZeroPointNotZero,
2983 TosaErrorValidator.evOutputZeroPointNotZero,
2984 TosaErrorValidator.evPadLargerEqualKernel,
2985 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002986 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002987 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00002988 "data_gen": {
2989 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002992 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002993 "conv2d_TEMPLATE": {
2994 "op": Op.CONV2D,
2995 "operands": (1, 2),
2996 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002997 "build_fcn": (
2998 build_conv2d,
2999 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003000 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003001 TosaArgGen.agConv,
3002 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003003 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003004 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003005 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3006 "error_if_validators": (
3007 TosaErrorValidator.evWrongInputType,
3008 TosaErrorValidator.evWrongOutputType,
3009 TosaErrorValidator.evWrongInputList,
3010 TosaErrorValidator.evWrongOutputList,
3011 TosaErrorValidator.evInputZeroPointNotZero,
3012 TosaErrorValidator.evWeightZeroPointNotZero,
3013 TosaErrorValidator.evPadSmallerZero,
3014 TosaErrorValidator.evStrideSmallerOne,
3015 TosaErrorValidator.evDilationSmallerOne,
3016 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003017 TosaErrorValidator.evConvOutputShapeMismatch,
3018 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003019 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003020 "data_gen": {
3021 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3022 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003023 "template": True,
3024 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003025 # Templated operator. Filled in by createDynamicOpLists
3026 "conv3d_TEMPLATE": {
3027 "op": Op.CONV3D,
3028 "operands": (1, 2),
3029 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003030 "build_fcn": (
3031 build_conv3d,
3032 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003033 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003034 TosaArgGen.agConv,
3035 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003036 "qgen": TosaQuantGen.qgConv,
3037 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003038 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3039 "error_if_validators": (
3040 TosaErrorValidator.evWrongInputType,
3041 TosaErrorValidator.evWrongOutputType,
3042 TosaErrorValidator.evWrongInputList,
3043 TosaErrorValidator.evWrongOutputList,
3044 TosaErrorValidator.evInputZeroPointNotZero,
3045 TosaErrorValidator.evWeightZeroPointNotZero,
3046 TosaErrorValidator.evPadSmallerZero,
3047 TosaErrorValidator.evStrideSmallerOne,
3048 TosaErrorValidator.evDilationSmallerOne,
3049 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003050 TosaErrorValidator.evConvOutputShapeMismatch,
3051 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003052 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003053 "template": True,
3054 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003055 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003056 "depthwise_conv2d_TEMPLATE": {
3057 "op": Op.DEPTHWISE_CONV2D,
3058 "operands": (1, 2),
3059 "filter": [1, 1],
3060 "rank": (4, 4),
3061 "build_fcn": (
3062 build_depthwise_conv2d,
3063 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003064 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003065 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003066 ),
3067 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003068 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003069 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3070 "error_if_validators": (
3071 TosaErrorValidator.evWrongInputType,
3072 TosaErrorValidator.evWrongOutputType,
3073 TosaErrorValidator.evWrongInputList,
3074 TosaErrorValidator.evWrongOutputList,
3075 TosaErrorValidator.evInputZeroPointNotZero,
3076 TosaErrorValidator.evWeightZeroPointNotZero,
3077 TosaErrorValidator.evPadSmallerZero,
3078 TosaErrorValidator.evStrideSmallerOne,
3079 TosaErrorValidator.evDilationSmallerOne,
3080 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003081 TosaErrorValidator.evConvOutputShapeMismatch,
3082 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003083 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003084 "template": True,
3085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 "fully_connected": {
3087 "op": Op.FULLY_CONNECTED,
3088 "operands": (1, 2),
3089 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 "build_fcn": (
3091 build_fully_connected,
3092 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003093 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003094 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003097 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evInputZeroPointNotZero,
3100 TosaErrorValidator.evWeightZeroPointNotZero,
3101 TosaErrorValidator.evWrongRank,
3102 TosaErrorValidator.evWrongInputType,
3103 TosaErrorValidator.evWrongOutputType,
3104 TosaErrorValidator.evWrongInputList,
3105 TosaErrorValidator.evWrongOutputList,
3106 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003107 "data_gen": {
3108 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3109 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003111 "matmul": {
3112 "op": Op.MATMUL,
3113 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003114 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003115 "build_fcn": (
3116 build_matmul,
3117 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003118 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003119 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003121 "qgen": TosaQuantGen.qgMatmul,
3122 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003123 "error_if_validators": (
3124 TosaErrorValidator.evInputZeroPointNotZero,
3125 TosaErrorValidator.evWrongRank,
3126 TosaErrorValidator.evWrongInputType,
3127 TosaErrorValidator.evWrongOutputType,
3128 TosaErrorValidator.evWrongInputList,
3129 TosaErrorValidator.evWrongOutputList,
3130 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003131 "data_gen": {
3132 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003133 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 "max_pool2d": {
3136 "op": Op.MAX_POOL2D,
3137 "operands": (1, 0),
3138 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003139 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003140 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003141 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003142 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003143 TosaArgGen.agPooling,
3144 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003146 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003147 "error_if_validators": (
3148 TosaErrorValidator.evKernelSmallerOne,
3149 TosaErrorValidator.evStrideSmallerOne,
3150 TosaErrorValidator.evPadSmallerZero,
3151 TosaErrorValidator.evWrongRank,
3152 TosaErrorValidator.evWrongInputType,
3153 TosaErrorValidator.evWrongOutputType,
3154 TosaErrorValidator.evWrongInputList,
3155 TosaErrorValidator.evWrongOutputList,
3156 TosaErrorValidator.evPadLargerEqualKernel,
3157 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003158 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003159 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003160 "data_gen": {
3161 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003163 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003164 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003165 "transpose_conv2d_TEMPLATE": {
3166 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003167 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003168 "rank": (4, 4),
3169 "build_fcn": (
3170 build_transpose_conv2d,
3171 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003172 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003173 TosaArgGen.agTransposeConv2D,
3174 ),
3175 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003176 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003177 "invalid_test_validators": (
3178 TosaInvalidValidator.ivHeightWidthInvalid,
3179 TosaInvalidValidator.ivNonPositiveOutputShape,
3180 ),
3181 "error_if_validators": (
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 TosaErrorValidator.evInputZeroPointNotZero,
3187 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003188 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003189 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003190 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003191 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003192 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003193 "template": True,
3194 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003195 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003196 "clamp": {
3197 "op": Op.CLAMP,
3198 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003199 "build_fcn": (
3200 build_clamp,
3201 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003202 TosaTensorValuesGen.tvgLazyGenDefault,
3203 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003204 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003205 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003206 "error_if_validators": (
3207 TosaErrorValidator.evMaxSmallerMin,
3208 TosaErrorValidator.evWrongInputType,
3209 TosaErrorValidator.evWrongOutputType,
3210 TosaErrorValidator.evWrongInputList,
3211 TosaErrorValidator.evWrongOutputList,
3212 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003213 "data_gen": {
3214 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3215 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003216 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003217 "sigmoid": {
3218 "op": Op.SIGMOID,
3219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003221 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003223 TosaTensorValuesGen.tvgLazyGenDefault,
3224 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003225 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003226 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evWrongInputType,
3229 TosaErrorValidator.evWrongOutputType,
3230 TosaErrorValidator.evWrongInputList,
3231 TosaErrorValidator.evWrongOutputList,
3232 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003233 "data_gen": {
3234 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3235 },
3236 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003237 },
3238 "tanh": {
3239 "op": Op.TANH,
3240 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003242 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003243 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003244 TosaTensorValuesGen.tvgLazyGenDefault,
3245 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003247 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003248 "error_if_validators": (
3249 TosaErrorValidator.evWrongInputType,
3250 TosaErrorValidator.evWrongOutputType,
3251 TosaErrorValidator.evWrongInputList,
3252 TosaErrorValidator.evWrongOutputList,
3253 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003254 "data_gen": {
3255 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3256 },
3257 "compliance": {"ulp": 5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08003258 },
Won Jeon78155c62023-06-10 00:20:04 +00003259 "erf": {
3260 "op": Op.ERF,
3261 "operands": (1, 0),
3262 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003263 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003264 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003265 TosaTensorValuesGen.tvgLazyGenDefault,
3266 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003267 ),
3268 "types": TYPE_FP,
3269 "error_if_validators": (
3270 TosaErrorValidator.evWrongInputType,
3271 TosaErrorValidator.evWrongOutputType,
3272 TosaErrorValidator.evWrongInputList,
3273 TosaErrorValidator.evWrongOutputList,
3274 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003275 "data_gen": {
3276 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3277 },
3278 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003279 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 # Elementwise Binary Operators
3281 "add": {
3282 "op": Op.ADD,
3283 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 "build_fcn": (
3285 build_binary_broadcast,
3286 TosaTensorGen.tgBroadcastFuzz,
3287 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003288 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003291 "error_if_validators": (
3292 TosaErrorValidator.evRankMismatch,
3293 TosaErrorValidator.evWrongInputType,
3294 TosaErrorValidator.evWrongOutputType,
3295 TosaErrorValidator.evWrongInputList,
3296 TosaErrorValidator.evWrongOutputList,
3297 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003298 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003299 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003300 "data_gen": {
3301 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3302 },
3303 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003305 "arithmetic_right_shift": {
3306 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3307 "operands": (2, 0),
3308 "build_fcn": (
3309 build_arithmetic_right_shift,
3310 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003311 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 TosaArgGen.agArithmeticRightShift,
3313 ),
3314 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003315 "error_if_validators": (
3316 TosaErrorValidator.evRankMismatch,
3317 TosaErrorValidator.evWrongInputType,
3318 TosaErrorValidator.evWrongOutputType,
3319 TosaErrorValidator.evWrongInputList,
3320 TosaErrorValidator.evWrongOutputList,
3321 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003322 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003323 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003325 "bitwise_and": {
3326 "op": Op.BITWISE_AND,
3327 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003328 "build_fcn": (
3329 build_binary_broadcast,
3330 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003331 TosaTensorValuesGen.tvgLazyGenDefault,
3332 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003333 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003335 "error_if_validators": (
3336 TosaErrorValidator.evRankMismatch,
3337 TosaErrorValidator.evWrongInputType,
3338 TosaErrorValidator.evWrongOutputType,
3339 TosaErrorValidator.evWrongInputList,
3340 TosaErrorValidator.evWrongOutputList,
3341 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003342 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003343 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003344 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003345 "bitwise_or": {
3346 "op": Op.BITWISE_OR,
3347 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 "build_fcn": (
3349 build_binary_broadcast,
3350 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003351 TosaTensorValuesGen.tvgLazyGenDefault,
3352 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003355 "error_if_validators": (
3356 TosaErrorValidator.evRankMismatch,
3357 TosaErrorValidator.evWrongInputType,
3358 TosaErrorValidator.evWrongOutputType,
3359 TosaErrorValidator.evWrongInputList,
3360 TosaErrorValidator.evWrongOutputList,
3361 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003362 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 "bitwise_xor": {
3366 "op": Op.BITWISE_XOR,
3367 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 "build_fcn": (
3369 build_binary_broadcast,
3370 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003371 TosaTensorValuesGen.tvgLazyGenDefault,
3372 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003373 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003375 "error_if_validators": (
3376 TosaErrorValidator.evRankMismatch,
3377 TosaErrorValidator.evWrongInputType,
3378 TosaErrorValidator.evWrongOutputType,
3379 TosaErrorValidator.evWrongInputList,
3380 TosaErrorValidator.evWrongOutputList,
3381 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003382 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003385 "intdiv": {
3386 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003387 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 "build_fcn": (
3389 build_binary_broadcast,
3390 TosaTensorGen.tgBroadcastFuzz,
3391 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003392 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003393 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003394 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003395 "error_if_validators": (
3396 TosaErrorValidator.evRankMismatch,
3397 TosaErrorValidator.evWrongInputType,
3398 TosaErrorValidator.evWrongOutputType,
3399 TosaErrorValidator.evWrongInputList,
3400 TosaErrorValidator.evWrongOutputList,
3401 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003402 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003403 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003404 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 "logical_and": {
3406 "op": Op.LOGICAL_AND,
3407 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003408 "build_fcn": (
3409 build_binary_broadcast,
3410 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003411 TosaTensorValuesGen.tvgLazyGenDefault,
3412 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003415 "error_if_validators": (
3416 TosaErrorValidator.evRankMismatch,
3417 TosaErrorValidator.evWrongInputType,
3418 TosaErrorValidator.evWrongOutputType,
3419 TosaErrorValidator.evWrongInputList,
3420 TosaErrorValidator.evWrongOutputList,
3421 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003422 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "logical_left_shift": {
3426 "op": Op.LOGICAL_LEFT_SHIFT,
3427 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 "build_fcn": (
3429 build_binary_broadcast,
3430 TosaTensorGen.tgBroadcastFuzz,
3431 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003432 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003433 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003435 "error_if_validators": (
3436 TosaErrorValidator.evRankMismatch,
3437 TosaErrorValidator.evWrongInputType,
3438 TosaErrorValidator.evWrongOutputType,
3439 TosaErrorValidator.evWrongInputList,
3440 TosaErrorValidator.evWrongOutputList,
3441 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003442 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003443 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "logical_right_shift": {
3446 "op": Op.LOGICAL_RIGHT_SHIFT,
3447 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003448 "build_fcn": (
3449 build_binary_broadcast,
3450 TosaTensorGen.tgBroadcastFuzz,
3451 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003452 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003453 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003455 "error_if_validators": (
3456 TosaErrorValidator.evRankMismatch,
3457 TosaErrorValidator.evWrongInputType,
3458 TosaErrorValidator.evWrongOutputType,
3459 TosaErrorValidator.evWrongInputList,
3460 TosaErrorValidator.evWrongOutputList,
3461 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003462 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003463 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 "logical_or": {
3466 "op": Op.LOGICAL_OR,
3467 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003468 "build_fcn": (
3469 build_binary_broadcast,
3470 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003471 TosaTensorValuesGen.tvgLazyGenDefault,
3472 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003473 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003475 "error_if_validators": (
3476 TosaErrorValidator.evRankMismatch,
3477 TosaErrorValidator.evWrongInputType,
3478 TosaErrorValidator.evWrongOutputType,
3479 TosaErrorValidator.evWrongInputList,
3480 TosaErrorValidator.evWrongOutputList,
3481 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003482 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 "logical_xor": {
3486 "op": Op.LOGICAL_XOR,
3487 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003488 "build_fcn": (
3489 build_binary_broadcast,
3490 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003491 TosaTensorValuesGen.tvgLazyGenDefault,
3492 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003493 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003494 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003495 "error_if_validators": (
3496 TosaErrorValidator.evRankMismatch,
3497 TosaErrorValidator.evWrongInputType,
3498 TosaErrorValidator.evWrongOutputType,
3499 TosaErrorValidator.evWrongInputList,
3500 TosaErrorValidator.evWrongOutputList,
3501 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003502 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 "maximum": {
3506 "op": Op.MAXIMUM,
3507 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003508 "build_fcn": (
3509 build_binary_broadcast,
3510 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003511 TosaTensorValuesGen.tvgLazyGenDefault,
3512 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003515 "error_if_validators": (
3516 TosaErrorValidator.evRankMismatch,
3517 TosaErrorValidator.evWrongInputType,
3518 TosaErrorValidator.evWrongOutputType,
3519 TosaErrorValidator.evWrongInputList,
3520 TosaErrorValidator.evWrongOutputList,
3521 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003522 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003523 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003524 "data_gen": {
3525 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3526 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003528 "minimum": {
3529 "op": Op.MINIMUM,
3530 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 "build_fcn": (
3532 build_binary_broadcast,
3533 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003534 TosaTensorValuesGen.tvgLazyGenDefault,
3535 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 "error_if_validators": (
3539 TosaErrorValidator.evRankMismatch,
3540 TosaErrorValidator.evWrongInputType,
3541 TosaErrorValidator.evWrongOutputType,
3542 TosaErrorValidator.evWrongInputList,
3543 TosaErrorValidator.evWrongOutputList,
3544 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003545 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003547 "data_gen": {
3548 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3549 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003550 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "mul": {
3552 "op": Op.MUL,
3553 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 "build_fcn": (
3555 build_mul,
3556 TosaTensorGen.tgBroadcastFuzz,
3557 TosaTensorValuesGen.tvgMul,
3558 TosaArgGen.agMul,
3559 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003560 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evWrongInputType,
3563 TosaErrorValidator.evWrongOutputType,
3564 TosaErrorValidator.evWrongInputList,
3565 TosaErrorValidator.evWrongOutputList,
3566 TosaErrorValidator.evRankMismatch,
3567 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003568 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003570 "data_gen": {
3571 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3572 },
3573 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 "pow": {
3576 "op": Op.POW,
3577 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003578 "build_fcn": (
3579 build_binary_broadcast,
3580 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003581 TosaTensorValuesGen.tvgPow,
3582 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003583 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003584 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003585 "error_if_validators": (
3586 TosaErrorValidator.evRankMismatch,
3587 TosaErrorValidator.evWrongInputType,
3588 TosaErrorValidator.evWrongOutputType,
3589 TosaErrorValidator.evWrongInputList,
3590 TosaErrorValidator.evWrongOutputList,
3591 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003592 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003594 "data_gen": {
3595 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3596 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 "sub": {
3599 "op": Op.SUB,
3600 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003601 "build_fcn": (
3602 build_binary_broadcast,
3603 TosaTensorGen.tgBroadcastFuzz,
3604 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003605 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003606 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003608 "error_if_validators": (
3609 TosaErrorValidator.evRankMismatch,
3610 TosaErrorValidator.evWrongInputType,
3611 TosaErrorValidator.evWrongOutputType,
3612 TosaErrorValidator.evWrongInputList,
3613 TosaErrorValidator.evWrongOutputList,
3614 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003615 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003617 "data_gen": {
3618 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3619 },
3620 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003621 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003622 "table": {
3623 "op": Op.TABLE,
3624 # Use the automatic generation functions to create the input array
3625 # but create the table tensor in the build function, as it may be
3626 # a different type from the input
3627 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003628 "build_fcn": (
3629 build_table,
3630 TosaTensorGen.tgBasic,
3631 TosaTensorValuesGen.tvgDefault,
3632 TosaArgGen.agTable,
3633 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003634 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003635 "error_if_validators": (
3636 TosaErrorValidator.evWrongInputType,
3637 TosaErrorValidator.evWrongOutputType,
3638 TosaErrorValidator.evWrongInputList,
3639 TosaErrorValidator.evWrongOutputList,
3640 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003641 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003642 # Elementwise Unary operators
3643 "abs": {
3644 "op": Op.ABS,
3645 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003646 "build_fcn": (
3647 build_unary,
3648 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003649 TosaTensorValuesGen.tvgLazyGenDefault,
3650 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003651 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003653 "error_if_validators": (
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongInputList,
3657 TosaErrorValidator.evWrongOutputList,
3658 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003659 "data_gen": {
3660 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3661 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003662 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003663 "bitwise_not": {
3664 "op": Op.BITWISE_NOT,
3665 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003666 "build_fcn": (
3667 build_unary,
3668 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003669 TosaTensorValuesGen.tvgLazyGenDefault,
3670 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003671 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003672 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 "error_if_validators": (
3674 TosaErrorValidator.evWrongInputType,
3675 TosaErrorValidator.evWrongOutputType,
3676 TosaErrorValidator.evWrongInputList,
3677 TosaErrorValidator.evWrongOutputList,
3678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003680 "ceil": {
3681 "op": Op.CEIL,
3682 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003683 "build_fcn": (
3684 build_unary,
3685 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003686 TosaTensorValuesGen.tvgLazyGenDefault,
3687 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003690 "error_if_validators": (
3691 TosaErrorValidator.evWrongInputType,
3692 TosaErrorValidator.evWrongOutputType,
3693 TosaErrorValidator.evWrongInputList,
3694 TosaErrorValidator.evWrongOutputList,
3695 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003696 "data_gen": {
3697 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3698 },
3699 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "clz": {
3702 "op": Op.CLZ,
3703 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_unary,
3706 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003707 TosaTensorValuesGen.tvgLazyGenDefault,
3708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003718 "exp": {
3719 "op": Op.EXP,
3720 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 "build_fcn": (
3722 build_unary,
3723 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003724 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003725 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003727 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 "error_if_validators": (
3729 TosaErrorValidator.evWrongInputType,
3730 TosaErrorValidator.evWrongOutputType,
3731 TosaErrorValidator.evWrongInputList,
3732 TosaErrorValidator.evWrongOutputList,
3733 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003734 "data_gen": {
3735 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3736 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "floor": {
3739 "op": Op.FLOOR,
3740 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 "build_fcn": (
3742 build_unary,
3743 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003744 TosaTensorValuesGen.tvgLazyGenDefault,
3745 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003746 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 "error_if_validators": (
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongInputList,
3752 TosaErrorValidator.evWrongOutputList,
3753 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003754 "data_gen": {
3755 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3756 },
3757 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "log": {
3760 "op": Op.LOG,
3761 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762 "build_fcn": (
3763 build_unary,
3764 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003765 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003766 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 "error_if_validators": (
3770 TosaErrorValidator.evWrongInputType,
3771 TosaErrorValidator.evWrongOutputType,
3772 TosaErrorValidator.evWrongInputList,
3773 TosaErrorValidator.evWrongOutputList,
3774 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003775 "data_gen": {
3776 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3777 },
3778 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003779 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 "logical_not": {
3781 "op": Op.LOGICAL_NOT,
3782 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003783 "build_fcn": (
3784 build_unary,
3785 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003786 TosaTensorValuesGen.tvgLazyGenDefault,
3787 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003788 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003790 "error_if_validators": (
3791 TosaErrorValidator.evWrongInputType,
3792 TosaErrorValidator.evWrongOutputType,
3793 TosaErrorValidator.evWrongInputList,
3794 TosaErrorValidator.evWrongOutputList,
3795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "negate": {
3798 "op": Op.NEGATE,
3799 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 "build_fcn": (
3801 build_unary,
3802 TosaTensorGen.tgBasic,
3803 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003804 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "qgen": TosaQuantGen.qgUnary,
3807 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 "error_if_validators": (
3809 TosaErrorValidator.evInputZeroPointNotZero,
3810 TosaErrorValidator.evOutputZeroPointNotZero,
3811 TosaErrorValidator.evWrongInputType,
3812 TosaErrorValidator.evWrongOutputType,
3813 TosaErrorValidator.evWrongInputList,
3814 TosaErrorValidator.evWrongOutputList,
3815 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003816 "data_gen": {
3817 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 "reciprocal": {
3821 "op": Op.RECIPROCAL,
3822 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003823 "build_fcn": (
3824 build_unary,
3825 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003826 TosaTensorValuesGen.tvgLazyGenDefault,
3827 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003830 "error_if_validators": (
3831 TosaErrorValidator.evWrongInputType,
3832 TosaErrorValidator.evWrongOutputType,
3833 TosaErrorValidator.evWrongInputList,
3834 TosaErrorValidator.evWrongOutputList,
3835 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003836 "data_gen": {
3837 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3838 },
3839 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "rsqrt": {
3842 "op": Op.RSQRT,
3843 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_unary,
3846 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00003847 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003848 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evWrongInputType,
3853 TosaErrorValidator.evWrongOutputType,
3854 TosaErrorValidator.evWrongInputList,
3855 TosaErrorValidator.evWrongOutputList,
3856 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003857 "data_gen": {
3858 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3859 },
3860 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 # Elementwise Ternary operators
3863 "select": {
3864 "op": Op.SELECT,
3865 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003866 "build_fcn": (
3867 build_select,
3868 TosaTensorGen.tgBroadcastFuzz,
3869 TosaTensorValuesGen.tvgSelect,
3870 None,
3871 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 "error_if_validators": (
3874 TosaErrorValidator.evRankMismatch,
3875 TosaErrorValidator.evWrongInputType,
3876 TosaErrorValidator.evWrongOutputType,
3877 TosaErrorValidator.evWrongInputList,
3878 TosaErrorValidator.evWrongOutputList,
3879 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003880 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003882 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 # Comparison operators
3884 "equal": {
3885 "op": Op.EQUAL,
3886 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003887 "build_fcn": (
3888 build_comparison,
3889 TosaTensorGen.tgBroadcastFuzz,
3890 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003891 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003894 "error_if_validators": (
3895 TosaErrorValidator.evRankMismatch,
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList,
3899 TosaErrorValidator.evWrongOutputList,
3900 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003901 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003902 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003903 "data_gen": {
3904 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "greater_equal": {
3908 "op": Op.GREATER_EQUAL,
3909 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 "build_fcn": (
3911 build_comparison,
3912 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003913 TosaTensorValuesGen.tvgLazyGenDefault,
3914 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003916 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 "error_if_validators": (
3918 TosaErrorValidator.evRankMismatch,
3919 TosaErrorValidator.evWrongInputType,
3920 TosaErrorValidator.evWrongOutputType,
3921 TosaErrorValidator.evWrongInputList,
3922 TosaErrorValidator.evWrongOutputList,
3923 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003924 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003926 "data_gen": {
3927 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3928 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003929 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "greater": {
3931 "op": Op.GREATER,
3932 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003933 "build_fcn": (
3934 build_comparison,
3935 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00003936 TosaTensorValuesGen.tvgLazyGenDefault,
3937 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003938 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003939 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 "error_if_validators": (
3941 TosaErrorValidator.evRankMismatch,
3942 TosaErrorValidator.evWrongInputType,
3943 TosaErrorValidator.evWrongOutputType,
3944 TosaErrorValidator.evWrongInputList,
3945 TosaErrorValidator.evWrongOutputList,
3946 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003947 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003948 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00003949 "data_gen": {
3950 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3951 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003952 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 # Reduction operators
3954 "reduce_all": {
3955 "op": Op.REDUCE_ALL,
3956 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003957 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003958 "build_fcn": (
3959 build_reduce,
3960 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003961 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 TosaArgGen.agAxis,
3963 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003965 "error_if_validators": (
3966 TosaErrorValidator.evAxisLargerRank,
3967 TosaErrorValidator.evAxisSmallerZero,
3968 TosaErrorValidator.evShapeOfAxisNotOne,
3969 TosaErrorValidator.evWrongInputType,
3970 TosaErrorValidator.evWrongOutputType,
3971 TosaErrorValidator.evWrongRank,
3972 TosaErrorValidator.evWrongInputList,
3973 TosaErrorValidator.evWrongOutputList,
3974 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "reduce_any": {
3977 "op": Op.REDUCE_ANY,
3978 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003979 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_reduce,
3982 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003983 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003984 TosaArgGen.agAxis,
3985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evAxisLargerRank,
3989 TosaErrorValidator.evAxisSmallerZero,
3990 TosaErrorValidator.evShapeOfAxisNotOne,
3991 TosaErrorValidator.evWrongInputType,
3992 TosaErrorValidator.evWrongOutputType,
3993 TosaErrorValidator.evWrongRank,
3994 TosaErrorValidator.evWrongInputList,
3995 TosaErrorValidator.evWrongOutputList,
3996 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003997 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003998 "reduce_max": {
3999 "op": Op.REDUCE_MAX,
4000 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004001 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004002 "build_fcn": (
4003 build_reduce,
4004 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004005 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004006 TosaArgGen.agAxis,
4007 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004008 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 "error_if_validators": (
4010 TosaErrorValidator.evAxisLargerRank,
4011 TosaErrorValidator.evAxisSmallerZero,
4012 TosaErrorValidator.evShapeOfAxisNotOne,
4013 TosaErrorValidator.evWrongInputType,
4014 TosaErrorValidator.evWrongOutputType,
4015 TosaErrorValidator.evWrongRank,
4016 TosaErrorValidator.evWrongInputList,
4017 TosaErrorValidator.evWrongOutputList,
4018 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004019 "data_gen": {
4020 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004023 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004024 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004026 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004027 "build_fcn": (
4028 build_reduce,
4029 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004030 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004031 TosaArgGen.agAxis,
4032 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004034 "error_if_validators": (
4035 TosaErrorValidator.evAxisLargerRank,
4036 TosaErrorValidator.evAxisSmallerZero,
4037 TosaErrorValidator.evShapeOfAxisNotOne,
4038 TosaErrorValidator.evWrongInputType,
4039 TosaErrorValidator.evWrongOutputType,
4040 TosaErrorValidator.evWrongRank,
4041 TosaErrorValidator.evWrongInputList,
4042 TosaErrorValidator.evWrongOutputList,
4043 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004044 "data_gen": {
4045 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4046 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004048 "reduce_product": {
4049 "op": Op.REDUCE_PRODUCT,
4050 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004051 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004052 "build_fcn": (
4053 build_reduce,
4054 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004055 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004056 TosaArgGen.agAxis,
4057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004059 "error_if_validators": (
4060 TosaErrorValidator.evAxisLargerRank,
4061 TosaErrorValidator.evAxisSmallerZero,
4062 TosaErrorValidator.evShapeOfAxisNotOne,
4063 TosaErrorValidator.evWrongInputType,
4064 TosaErrorValidator.evWrongOutputType,
4065 TosaErrorValidator.evWrongRank,
4066 TosaErrorValidator.evWrongInputList,
4067 TosaErrorValidator.evWrongOutputList,
4068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004069 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004070 "reduce_sum": {
4071 "op": Op.REDUCE_SUM,
4072 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004073 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004074 "build_fcn": (
4075 build_reduce,
4076 TosaTensorGen.tgBasic,
4077 TosaTensorValuesGen.tvgReduceSum,
4078 TosaArgGen.agAxis,
4079 ),
James Ward24dbc422022-10-19 12:20:31 +01004080 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004081 "error_if_validators": (
4082 TosaErrorValidator.evAxisLargerRank,
4083 TosaErrorValidator.evAxisSmallerZero,
4084 TosaErrorValidator.evShapeOfAxisNotOne,
4085 TosaErrorValidator.evWrongInputType,
4086 TosaErrorValidator.evWrongOutputType,
4087 TosaErrorValidator.evWrongRank,
4088 TosaErrorValidator.evWrongInputList,
4089 TosaErrorValidator.evWrongOutputList,
4090 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004091 "data_gen": {
4092 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004095 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004096 "concat": {
4097 "op": Op.CONCAT,
4098 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004099 "build_fcn": (
4100 build_concat,
4101 TosaTensorGen.tgConcat,
4102 TosaTensorValuesGen.tvgConcat,
4103 TosaArgGen.agAxis,
4104 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004105 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 "error_if_validators": (
4107 TosaErrorValidator.evAxisLargerRank,
4108 TosaErrorValidator.evAxisSmallerZero,
4109 TosaErrorValidator.evConcatInputRankMismatch,
4110 TosaErrorValidator.evConcatShapeSumMismatch,
4111 TosaErrorValidator.evConcatInputDimMismatch,
4112 TosaErrorValidator.evWrongInputType,
4113 TosaErrorValidator.evWrongOutputType,
4114 TosaErrorValidator.evWrongOutputList,
4115 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004116 },
4117 "pad": {
4118 "op": Op.PAD,
4119 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004120 "build_fcn": (
4121 build_pad,
4122 TosaTensorGen.tgBasic,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004123 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004124 TosaArgGen.agPad,
4125 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004126 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004127 "error_if_validators": (
4128 TosaErrorValidator.evWrongInputType,
4129 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004130 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004131 TosaErrorValidator.evWrongOutputType,
4132 TosaErrorValidator.evWrongInputList,
4133 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004134 TosaErrorValidator.evRankMismatch,
4135 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004136 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004137 "data_gen": {
4138 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4139 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004140 },
Won Jeona21b2e82023-08-10 10:33:01 +00004141 "dim": {
4142 "op": Op.DIM,
4143 "operands": (1, 0),
4144 "build_fcn": (
4145 build_dim,
4146 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004147 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004148 TosaArgGen.agAxis,
4149 ),
4150 "types": TYPE_FIB,
4151 "error_if_validators": (
4152 TosaErrorValidator.evAxisLargerRank,
4153 TosaErrorValidator.evAxisSmallerZero,
4154 TosaErrorValidator.evWrongInputType,
4155 TosaErrorValidator.evWrongInputList,
4156 TosaErrorValidator.evWrongOutputList,
4157 TosaErrorValidator.evWrongRank,
4158 ),
4159 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004160 "reshape": {
4161 "op": Op.RESHAPE,
4162 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004163 "build_fcn": (
4164 build_reshape,
4165 TosaTensorGen.tgBasic,
4166 TosaTensorValuesGen.tvgDefault,
4167 TosaArgGen.agReshape,
4168 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004169 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004170 "error_if_validators": (
4171 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4172 TosaErrorValidator.evWrongInputType,
4173 TosaErrorValidator.evWrongOutputType,
4174 TosaErrorValidator.evWrongInputList,
4175 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00004176 TosaErrorValidator.evReshapeOutputSizeMultiInference,
4177 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004178 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004179 },
4180 "reverse": {
4181 "op": Op.REVERSE,
4182 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004183 "build_fcn": (
4184 build_reverse,
4185 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004186 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004187 TosaArgGen.agAxis,
4188 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004189 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004190 "error_if_validators": (
4191 TosaErrorValidator.evAxisSmallerZero,
4192 TosaErrorValidator.evAxisLargerRank,
4193 TosaErrorValidator.evWrongInputType,
4194 TosaErrorValidator.evWrongOutputType,
4195 TosaErrorValidator.evWrongInputList,
4196 TosaErrorValidator.evWrongOutputList,
4197 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004198 },
4199 "slice": {
4200 "op": Op.SLICE,
4201 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004202 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 "build_fcn": (
4204 build_slice,
4205 TosaTensorGen.tgBasic,
4206 TosaTensorValuesGen.tvgDefault,
4207 TosaArgGen.agSlice,
4208 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004209 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004210 "error_if_validators": (
4211 TosaErrorValidator.evStartSmallerZero,
4212 TosaErrorValidator.evSizeSmallerEqualZero,
4213 TosaErrorValidator.evStartSizeOutsideBounds,
4214 TosaErrorValidator.evSizeOutputShapeMismatch,
4215 TosaErrorValidator.evInputSizeStartLengthMismatch,
4216 TosaErrorValidator.evWrongRank,
4217 TosaErrorValidator.evWrongInputType,
4218 TosaErrorValidator.evWrongOutputType,
4219 TosaErrorValidator.evWrongInputList,
4220 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004221 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004222 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004223 },
4224 "tile": {
4225 "op": Op.TILE,
4226 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004227 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 "build_fcn": (
4229 build_tile,
4230 TosaTensorGen.tgBasic,
4231 TosaTensorValuesGen.tvgDefault,
4232 TosaArgGen.agTile,
4233 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004234 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004235 "error_if_validators": (
4236 TosaErrorValidator.evWrongInputType,
4237 TosaErrorValidator.evWrongOutputType,
4238 TosaErrorValidator.evWrongInputList,
4239 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004240 TosaErrorValidator.evRankMismatch,
4241 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004242 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004243 },
4244 "transpose": {
4245 "op": Op.TRANSPOSE,
4246 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004247 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004248 "build_fcn": (
4249 build_transpose,
4250 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004251 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004252 TosaArgGen.agTranspose,
4253 ),
4254 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004255 "error_if_validators": (
4256 TosaErrorValidator.evIndexOutsideBounds,
4257 TosaErrorValidator.evIndexUsedTwice,
4258 TosaErrorValidator.evWrongInputType,
4259 TosaErrorValidator.evWrongOutputType,
4260 TosaErrorValidator.evWrongInputList,
4261 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004262 TosaErrorValidator.evWrongRank,
4263 TosaErrorValidator.evRankMismatch,
4264 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004265 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004266 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004267 # Data nodes
4268 "const": {
4269 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004270 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004271 "build_fcn": (
4272 build_const,
4273 TosaTensorGen.tgBasic,
4274 TosaTensorValuesGen.tvgDefault,
4275 None,
4276 ),
Luke Hutton65872422023-02-20 10:33:04 +00004277 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004279 "identity": {
4280 "op": Op.IDENTITY,
4281 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_unary,
4284 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004285 TosaTensorValuesGen.tvgLazyGenDefault,
4286 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004287 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004288 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004289 "data_gen": {
4290 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004292 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004293 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004294 "gather": {
4295 "op": Op.GATHER,
4296 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4297 "operands": (1, 0),
4298 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 "build_fcn": (
4300 build_gather,
4301 TosaTensorGen.tgBasic,
4302 TosaTensorValuesGen.tvgDefault,
4303 None,
4304 ),
James Ward24dbc422022-10-19 12:20:31 +01004305 "types": (
4306 DType.INT8,
4307 DType.INT16,
4308 DType.INT32,
4309 DType.FP16,
4310 DType.BF16,
4311 DType.FP32,
4312 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004313 "error_if_validators": (
4314 TosaErrorValidator.evWrongInputType,
4315 TosaErrorValidator.evWrongOutputType,
4316 TosaErrorValidator.evWrongInputList,
4317 TosaErrorValidator.evWrongOutputList,
4318 TosaErrorValidator.evWrongRank,
4319 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004320 },
4321 "scatter": {
4322 "op": Op.SCATTER,
4323 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004324 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325 "operands": (2, 0),
4326 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004327 "build_fcn": (
4328 build_scatter,
4329 TosaTensorGen.tgScatter,
4330 TosaTensorValuesGen.tvgDefault,
4331 None,
4332 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004333 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004334 "error_if_validators": (
4335 TosaErrorValidator.evWrongInputType,
4336 TosaErrorValidator.evWrongOutputType,
4337 TosaErrorValidator.evWrongInputList,
4338 TosaErrorValidator.evWrongOutputList,
4339 TosaErrorValidator.evWrongRank,
4340 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004341 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004342 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004343 "resize": {
4344 "op": Op.RESIZE,
4345 "operands": (1, 0),
4346 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 "build_fcn": (
4348 build_resize,
4349 TosaTensorGen.tgNHWC,
4350 TosaTensorValuesGen.tvgDefault,
4351 TosaArgGen.agResize,
4352 ),
James Ward24dbc422022-10-19 12:20:31 +01004353 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 "invalid_test_validators": (
4355 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004356 ),
4357 "error_if_validators": (
4358 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004359 TosaErrorValidator.evScaleSmallerEqualZero,
4360 TosaErrorValidator.evScaleNLargerMax,
4361 TosaErrorValidator.evScaleDLargerMax,
4362 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004363 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004364 TosaErrorValidator.evBorderSmallerMin,
4365 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004366 TosaErrorValidator.evWrongInputType,
4367 TosaErrorValidator.evWrongOutputType,
4368 TosaErrorValidator.evWrongRank,
4369 TosaErrorValidator.evWrongInputList,
4370 TosaErrorValidator.evWrongOutputList,
4371 TosaErrorValidator.evBatchMismatch,
4372 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004373 TosaErrorValidator.evResizeOutputShapeMismatch,
4374 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004375 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004376 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004377 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004378 "cast": {
4379 "op": Op.CAST,
4380 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004381 "build_fcn": (
4382 build_cast,
4383 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004384 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004385 TosaArgGen.agCast,
4386 ),
James Ward8b390432022-08-12 20:48:56 +01004387 "types": (
4388 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004389 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004390 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004391 DType.INT8,
4392 DType.INT16,
4393 DType.INT32,
4394 DType.BOOL,
4395 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004396 "error_if_validators": (
4397 TosaErrorValidator.evWrongInputType,
4398 TosaErrorValidator.evWrongOutputType,
4399 TosaErrorValidator.evWrongInputList,
4400 TosaErrorValidator.evWrongOutputList,
4401 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004402 "data_gen": {
4403 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4404 },
4405 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004406 },
4407 "rescale": {
4408 "op": Op.RESCALE,
4409 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004410 "build_fcn": (
4411 build_rescale,
4412 TosaTensorGen.tgBasic,
4413 TosaTensorValuesGen.tvgDefault,
4414 TosaArgGen.agRescale,
4415 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004416 "types": [
4417 DType.UINT8,
4418 DType.INT8,
4419 DType.INT16,
4420 DType.INT32,
4421 DType.INT48,
4422 DType.UINT16,
4423 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004424 "error_if_validators": (
4425 TosaErrorValidator.evInputZeroPointNotZero,
4426 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004427 TosaErrorValidator.evU16InputZeroPointNotValid,
4428 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004429 TosaErrorValidator.evScaleTrue,
4430 TosaErrorValidator.evScaleNotTrue,
4431 TosaErrorValidator.evWrongInputType,
4432 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004433 TosaErrorValidator.evWrongInputList,
4434 TosaErrorValidator.evWrongOutputList,
4435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004436 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004437 # Custom
4438 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004439 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004440 # Two varients of cond_if, one that generates one of two constant tensors (no
4441 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4442 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004443 "cond_if_const": {
4444 "op": Op.COND_IF,
4445 "operands": (0, 2),
4446 "build_fcn": (
4447 build_cond_if_const,
4448 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004449 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004450 TosaArgGen.agCondIf,
4451 ),
4452 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004453 "error_if_validators": (
4454 TosaErrorValidator.evOutputListThenGraphMismatch,
4455 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004456 TosaErrorValidator.evCondIfCondNotMatchingBool,
4457 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004458 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004459 },
4460 "cond_if_binary": {
4461 "op": Op.COND_IF,
4462 "operands": (2, 0),
4463 "build_fcn": (
4464 build_cond_if_binary,
4465 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004466 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467 TosaArgGen.agCondIf,
4468 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004469 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 "error_if_validators": (
4471 TosaErrorValidator.evInputListThenGraphMismatch,
4472 TosaErrorValidator.evInputListElseGraphMismatch,
4473 TosaErrorValidator.evOutputListThenGraphMismatch,
4474 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004475 TosaErrorValidator.evCondIfCondNotMatchingBool,
4476 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004477 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004478 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004479 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004480 "while_loop": {
4481 "op": Op.WHILE_LOOP,
4482 "operands": (0, 1),
4483 "build_fcn": (
4484 build_while_loop,
4485 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004486 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 TosaArgGen.agWhileLoop,
4488 ),
4489 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004490 "error_if_validators": (
4491 TosaErrorValidator.evInputListOutputListMismatch,
4492 TosaErrorValidator.evInputListCondGraphMismatch,
4493 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4494 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4495 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004496 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004497 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004498 },
Luke Hutton57287132023-02-06 14:54:18 +00004499 "fft2d": {
4500 "op": Op.FFT2D,
4501 "operands": (2, 0),
4502 "rank": (3, 3),
4503 "build_fcn": (
4504 build_fft2d,
4505 TosaTensorGen.tgFFT2d,
4506 TosaTensorValuesGen.tvgDefault,
4507 TosaArgGen.agFFT2d,
4508 ),
4509 "types": [DType.FP32],
4510 "error_if_validators": (
4511 TosaErrorValidator.evWrongInputType,
4512 TosaErrorValidator.evWrongOutputType,
4513 TosaErrorValidator.evWrongInputList,
4514 TosaErrorValidator.evWrongOutputList,
4515 TosaErrorValidator.evWrongRank,
4516 TosaErrorValidator.evBatchMismatch,
4517 TosaErrorValidator.evKernelNotPowerOfTwo,
4518 TosaErrorValidator.evFFTInputShapeMismatch,
4519 TosaErrorValidator.evFFTOutputShapeMismatch,
4520 ),
4521 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004522 "rfft2d": {
4523 "op": Op.RFFT2D,
4524 "operands": (1, 0),
4525 "rank": (3, 3),
4526 "build_fcn": (
4527 build_rfft2d,
4528 TosaTensorGen.tgRFFT2d,
4529 TosaTensorValuesGen.tvgDefault,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004530 None,
Luke Hutton261b7b62023-01-10 14:50:31 +00004531 ),
4532 "types": [DType.FP32],
4533 "error_if_validators": (
4534 TosaErrorValidator.evWrongInputType,
4535 TosaErrorValidator.evWrongOutputType,
4536 TosaErrorValidator.evWrongInputList,
4537 TosaErrorValidator.evWrongOutputList,
4538 TosaErrorValidator.evWrongRank,
4539 TosaErrorValidator.evBatchMismatch,
4540 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004541 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004542 ),
4543 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004544 }
4545
Kevin Cheng550ccc52021-03-03 11:21:43 -08004546
Eric Kunzee5e26762020-10-13 16:11:07 -07004547class OutputShaper:
4548 # Methods in this class compute the expected output shape and datatype
4549 # for common classes of operations
4550 def __init__(self):
4551 pass
4552
4553 # These methods return arguments that can be used for
4554 # creating a new output tensor
4555 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004556 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4557 if error_name != ErrorIf.RankMismatch:
4558 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004559 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004560
4561 shape = []
4562 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004564 shape.append(b.shape[i])
4565 else:
4566 shape.append(a.shape[i])
4567
Jerry Ge135c9552023-05-23 20:59:32 +00004568 fuzz_idx = rng.integers(0, len(a.shape))
4569 if error_name == ErrorIf.DimensionMismatch:
4570 shape[fuzz_idx] += 1
4571
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004572 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004573 all_dtypes = [
4574 DType.INT8,
4575 DType.INT16,
4576 DType.INT32,
4577 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004578 DType.FP16,
4579 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004580 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004581 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004582 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4583 outputDType = rng.choice(wrong_dtypes)
4584 else:
4585 outputDType = a.dtype
4586
4587 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004588
4589 @staticmethod
4590 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004591 assert len(a.shape) == len(b.shape)
4592 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004593
4594 shape = []
4595 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004596 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004597 shape.append(a.shape[i])
4598
Kevin Cheng550ccc52021-03-03 11:21:43 -08004599 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004600
4601 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004602 def unaryOp(ser, rng, a, error_name=None):
4603 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004604 all_dtypes = [
4605 DType.INT8,
4606 DType.INT16,
4607 DType.INT32,
4608 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004609 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004610 DType.FP16,
4611 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004612 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004613 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4614 outputDType = rng.choice(wrong_dtypes)
4615 else:
4616 outputDType = a.dtype
4617
4618 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004619
4620 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004621 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004622 if error_name != ErrorIf.RankMismatch:
4623 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004624 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004625
4626 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004627 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004628 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004629 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4630 else:
4631 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004632
Jerry Ge135c9552023-05-23 20:59:32 +00004633 fuzz_idx = rng.integers(0, len(a.shape))
4634 if error_name == ErrorIf.DimensionMismatch:
4635 shape[fuzz_idx] += 1
4636
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004637 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 all_dtypes = [
4639 DType.INT8,
4640 DType.INT16,
4641 DType.INT32,
4642 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004643 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004644 DType.FP16,
4645 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004646 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004647 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4648 outputDType = rng.choice(wrong_dtypes)
4649 else:
4650 outputDType = a.dtype
4651
4652 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004653
4654 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004655 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004656 if error_name != ErrorIf.RankMismatch:
4657 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004658 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004659
4660 # Do broadcast
4661 shape = []
4662 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004663 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004664 shape.append(b.shape[i])
4665 else:
4666 shape.append(a.shape[i])
4667
Jerry Ge135c9552023-05-23 20:59:32 +00004668 fuzz_idx = rng.integers(0, len(a.shape))
4669 if error_name == ErrorIf.DimensionMismatch:
4670 shape[fuzz_idx] += 1
4671
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004672 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004673 wrong_dtypes = [
4674 DType.INT8,
4675 DType.INT16,
4676 DType.INT32,
4677 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004678 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004679 DType.FP16,
4680 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004681 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004682 outputDType = rng.choice(wrong_dtypes)
4683 else:
4684 outputDType = DType.BOOL
4685
4686 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004687
4688 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004689 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004690 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004691 if error_name not in [
4692 ErrorIf.AxisSmallerZero,
4693 ErrorIf.AxisLargerRank,
4694 ErrorIf.ShapeOfAxisNotOne,
4695 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004696 shape[axis] = 1
4697 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4698 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004699
Matthew Haddond6ce7252021-09-29 15:35:44 +01004700 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004701 all_dtypes = [
4702 DType.INT8,
4703 DType.INT16,
4704 DType.INT32,
4705 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004706 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004707 DType.FP16,
4708 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004709 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004710 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4711 outputDType = rng.choice(wrong_dtypes)
4712 else:
4713 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004714
Matthew Haddond6ce7252021-09-29 15:35:44 +01004715 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004716
4717 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004718 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004719 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004720
4721 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4722 del shape[axis]
4723
4724 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4725 remove = rng.choice([True, False])
4726 if remove and len(shape) > 1:
4727 del shape[0]
4728 else:
4729 shape.append(1)
4730 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4731 for i in range(len(shape)):
4732 shape[i] = shape[i] + rng.integers(1, 10)
4733
4734 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004735 all_dtypes = [
4736 DType.INT8,
4737 DType.INT16,
4738 DType.INT32,
4739 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004740 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004741 DType.FP16,
4742 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004743 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004744 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4745 outputDType = rng.choice(wrong_dtypes)
4746 else:
4747 outputDType = DType.INT32
4748
4749 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004750
4751 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV2D output via ``ser.addOutput`` and return its result.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC.  Each spatial size is
    ``(in - 1 + pad_lo + pad_hi - (k - 1) * dilation) // stride + 1`` and the
    dtype is ``accum_dtype`` unless an ERROR_IF case overrides it.
    """

    def _out_size(in_size, pad_lo, pad_hi, k_size, dilation, stride):
        return (in_size - 1 + pad_lo + pad_hi - (k_size - 1) * dilation) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole strides so the non-integer output size ERROR_IF is not
        # triggered as well.
        if which in (1, 3):
            out_h = out_h + rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w = out_w + rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With an invalid input dtype any plausible output dtype will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are acceptable
        # outputs for it); otherwise only the expected dtype is excluded.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004802
4803 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the CONV3D output via ``ser.addOutput`` and return its result.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC.  Each spatial size is
    ``(in - 1 + pad_lo + pad_hi - (k - 1) * dilation) // stride + 1`` and the
    dtype is ``accum_dtype`` unless an ERROR_IF case overrides it.
    """

    def _out_size(in_size, pad_lo, pad_hi, k_size, dilation, stride):
        return (in_size - 1 + pad_lo + pad_hi - (k_size - 1) * dilation) // stride + 1

    out_d = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)  # 1: depth, 2: height, 3: width, 4: all three
        # Grow by whole strides so the non-integer output size ERROR_IF is not
        # triggered as well.
        if which in (1, 4):
            out_d = out_d + rng.choice(steps) * strides[0]
        if which in (2, 4):
            out_h = out_h + rng.choice(steps) * strides[1]
        if which in (3, 4):
            out_w = out_w + rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With an invalid input dtype any plausible output dtype will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are acceptable
        # outputs for it); otherwise only the expected dtype is excluded.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4864
4865 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the DEPTHWISE_CONV2D output via ``ser.addOutput``.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M).  Each spatial size is
    ``(in - 1 + pad_lo + pad_hi - (k - 1) * dilation) // stride + 1`` and the
    dtype is ``accum_dtype`` unless an ERROR_IF case overrides it.
    """

    def _out_size(in_size, pad_lo, pad_hi, k_size, dilation, stride):
        return (in_size - 1 + pad_lo + pad_hi - (k_size - 1) * dilation) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole strides so the non-integer output size ERROR_IF is not
        # triggered as well.
        if which in (1, 3):
            out_h = out_h + rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w = out_w + rng.choice(steps) * strides[1]

    # Output channels are C * M from the HWCM filter.
    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # With an invalid input dtype any plausible output dtype will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are acceptable
        # outputs for it); otherwise only the expected dtype is excluded.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004915
4916 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the NHWC pooling output via ``ser.addOutput``.

    Spatial sizes are ``(in + pad_lo + pad_hi - kernel) // stride + 1``; the
    batch, channels and dtype follow ``ifm`` (the dtype is replaced for
    WrongOutputType).
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride/padding makes the test negative anyway; use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        # Grow by whole strides so the non-integer output size ERROR_IF is not
        # triggered as well.
        if which in (1, 3):
            out_h = out_h + rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w = out_w + rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004953
4954 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the FULLY_CONNECTED output via ``ser.addOutput``.

    ``input`` is [N, IC] and ``filter`` is [OC, IC], so the output is [N, OC].
    The dtype is ``accum_dtype`` as given; it was validated (or deliberately
    invalidated for ERROR_IF tests) by the argument generator.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004966
4967 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the MATMUL output via ``ser.addOutput`` and return its result.

    ``a`` is [N, H, C] and ``b`` is [N, C, W], giving an [N, H, W] output whose
    dtype is ``accum_dtype`` (validated by the argument generator).  ERROR_IF
    cases replace the dtype:

    * WrongOutputType - a dtype that is invalid for ``a.dtype``;
    * WrongInputType - INT32, as some potentially correct output dtype.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Any other input dtype: everything except the expected accumulator
            # dtype is wrong.  Without this branch ``incorrect_types`` would be
            # unbound and ``rng.choice`` would raise UnboundLocalError.
            incorrect_types = tuple(
                dtype
                for dtype in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dtype != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005014
5015 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the CONCAT output via ``ser.addOutput`` and return its result.

    The output starts from the first input's shape.  When computable, the
    ``axis`` dimension is summed over all inputs.  The dtype follows the first
    input unless WrongOutputType picks a different one.
    """
    first = inputs[0]
    others = inputs[1:]

    out_shape = first.shape.copy()
    # Sizes can only be summed when the ranks agree and the axis is valid.
    summable = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if summable:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005050
5051 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the PAD output via ``ser.addOutput`` and return its result.

    Output dim ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.  The
    dtype follows ``a`` unless WrongOutputType picks a different one.
    """
    out_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005081
5082 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the DIM output via ``ser.addOutput`` and return its result.

    The output has shape ``[]`` and dtype SHAPE.  WrongOutputType instead picks
    one of the listed non-SHAPE dtypes.  ``a`` and ``axis`` are not read here.
    """
    if error_name == ErrorIf.WrongOutputType:
        non_shape_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape_dtypes)))
    else:
        out_dtype = DType.SHAPE

    return ser.addOutput([], out_dtype)
5102
5103 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the RESHAPE output via ``ser.addOutput`` and return its result.

    The output shape is a copy of ``shape``.  For TensorSizeInputOutputMismatch
    every dim is grown by a random amount.  The dtype follows ``a`` unless
    WrongOutputType picks a different one.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005127
5128 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the SLICE output via ``ser.addOutput`` and return its result.

    The output shape is a copy of ``size`` and the dtype follows ``input``.
    ERROR_IF cases perturb the dtype, the dims or the rank.  ``start`` is not
    read here.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Dims <= 2 are only grown, so a -1/-2 step cannot take them to 0.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005161
5162 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the TILE output via ``ser.addOutput`` and return its result.

    Output dim ``i`` is ``a.shape[i] * multiples[i]``.  The dtype follows ``a``
    unless WrongOutputType picks a different one.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [dim * mult for dim, mult in zip(a.shape, multiples)]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005190
5191 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the TRANSPOSE output via ``ser.addOutput`` and return its result.

    Output dim ``i`` is ``a.shape[perms[i]]``.  For invalid permutations
    (out-of-bounds or repeated entries) the input shape is kept as-is.  The
    dtype follows ``a`` unless WrongOutputType picks a different one.
    """
    assert len(perms) == len(a.shape)

    out_shape = a.shape.copy()
    if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        out_shape = [a.shape[perm] for perm in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005223
5224 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the GATHER output via ``ser.addOutput`` and return its result.

    Unless the test is WrongRank, ``values`` must be rank 3 and ``indices``
    rank 2, with equal leading dims.  The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.  The dtype
    follows ``values`` unless WrongOutputType picks a different one.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005249
5250 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the SCATTER output via ``ser.addOutput`` and return its result.

    Unless the test is WrongRank, ``values_in`` and ``input`` must be rank 3
    and ``indices`` rank 2, with the N/W/C dims consistent as asserted below.
    The output uses ``values_in.shape`` (the same list object, not a copy).
    The dtype follows ``values_in`` unless WrongOutputType picks a different one.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005278
5279 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the TABLE output via ``ser.addOutput`` and return its result.

    The output keeps ``input``'s shape.  INT16 input yields INT32 and any other
    input yields INT8.  The input must be INT8 or INT16 unless the test is
    WrongInputType.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    expected_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
            if dtype != expected_dtype
        ]
        out_dtype = rng.choice(candidates)
    else:
        out_dtype = expected_dtype

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005298
5299 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the NHWC RESIZE output via ``serializer.addOutput``.

    ``scale`` holds ``[y_n, y_d, x_n, x_d]`` and each output size is
    ``((in - 1) * n - offset + border) // d + 1``.  For ERROR_IF tests the sizes
    are first clamped so the tensor stays valid, then optionally perturbed.
    ``mode`` and ``input_dtype`` are not read here.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp the scale terms so the output size below can be computed.
        y_n, y_d, x_n, x_d = max(y_n, 1), max(y_d, 1), max(x_n, 1), max(x_d, 1)

    out_h = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    out_w = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # Altered scale/offset/border values can give a degenerate size: keep
        # it at least 1 and, unless testing MaxDimExceeded, below the maximum.
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            out_h = min(out_h, gtu.MAX_RESIZE_DIMENSION - 1)
            out_w = min(out_w, gtu.MAX_RESIZE_DIMENSION - 1)

    def _nudge(size, step):
        # Move by one step; shrink instead when growing would reach the maximum.
        if size + step >= gtu.MAX_RESIZE_DIMENSION:
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size
        return size + step

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # Move by whole denominators so the non-integer ERROR_IF is not hit too.
        which = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        if which in (1, 3):
            out_h = _nudge(out_h, y_d)
        if which in (2, 3):
            out_w = _nudge(out_w, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # The batch size is repeated in the last position; input.shape[3] is
        # not read in this case.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005377
5378 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create a type-conversion output via ``ser.addOutput``.

    The output keeps ``val``'s shape and takes ``out_dtype``.  ``rng`` and
    ``error_name`` are not read.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005381
5382 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the TRANSPOSE_CONV2D output via ``ser.addOutput``.

    The output shape is ``output_shape`` itself, not a copy.  For
    ConvOutputShapeMismatch, entries 1 and/or 2 (H/W, assuming NHWC) are grown
    by 1-3 in place, so the caller's list is modified as well.  The dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)  # 1: height, 2: width, 3: both
        if which in (1, 3):
            output_shape[1] += rng.choice(steps)
        if which in (2, 3):
            output_shape[2] += rng.choice(steps)

    # With an invalid input dtype any plausible output dtype will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 (presumably both are acceptable
        # outputs for it); otherwise only the expected dtype is excluded.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005407
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for an FFT2D operator.

    Both outputs share one shape and dtype.  By default these are a copy of
    ``ifm1.shape`` and ``ifm1.dtype``.  They are presumably the real and
    imaginary result parts (confirm against the operator definition).

    Args:
        serializer: serializer exposing ``addOutput(shape, dtype)``.
        rng: random generator providing ``choice`` and ``integers``
            (presumably ``numpy.random.Generator``).
        ifm1: first input tensor; must share ``dtype`` with ``ifm2``.
        ifm2: second input tensor; its shape must equal ``ifm1``'s unless
            ``error_name`` is ``ErrorIf.FFTInputShapeMismatch``.
        error_name: optional ``ErrorIf`` case to inject.  Supported here:
            WrongOutputType (non-FP32 dtype), BatchMismatch (dim 0 grown by
            1..9) and FFTOutputShapeMismatch (dim 1 or 2 grown by 1..9).

    Returns:
        A list of two output tensors as returned by ``serializer.addOutput``.

    Raises:
        AssertionError: if an input invariant does not hold for the given
            ``error_name`` (rank must be 3 unless ``ErrorIf.WrongRank``).
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    shape_in = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(shape_in) == 3

    # Copy so the error cases below never mutate the input tensor's shape.
    out_shape = shape_in.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    # Two separate output tensors are registered with identical shape/dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5438
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for an RFFT2D operator.

    The output shape equals the input shape, except that the last dim W
    becomes ``W // 2 + 1``.  This is presumably the non-redundant half
    spectrum of a real-input FFT.  Both outputs share that shape and the
    input dtype.

    Args:
        serializer: serializer exposing ``addOutput(shape, dtype)``.
        rng: random generator providing ``choice`` and ``integers``
            (presumably ``numpy.random.Generator``).
        value: input tensor; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` case to inject.  Supported here:
            WrongOutputType (non-FP32 dtype), BatchMismatch (dim 0 grown by
            1..9) and FFTOutputShapeMismatch (dim 1 or 2 grown by 1..9).

    Returns:
        A list of two output tensors as returned by ``serializer.addOutput``.

    Raises:
        AssertionError: if the input rank is not 3 (unless
            ``error_name`` is ``ErrorIf.WrongRank``).
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    # Two separate output tensors are registered with identical shape/dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5462 return outputs