blob: 8fcea2983053afc8a2127945382fe1f4e4142fb1 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner comment that serialize() writes at the start of the generated
# C++ meta data files. The year is evaluated once, when this module loads.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
# Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
TOSA_TENSOR_MAX_RANK = 6
# Further 8K level maximums, for scale, kernel and stride respectively.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Main compliance dot product statistical test range
TOSA_MI_DOT_PRODUCT_TEST_SETS = range(6)
TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator state from the supplied options object.

    args is stored unchanged. Its output_dir, random_seed,
    tensor_fp_value_range and lazy_data_gen attributes are read here;
    generate_lib_path is read only when lazy_data_gen is false.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer exists until createSerializer() is called
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Floating point limits, whichever order the range was given in
    fp_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_range)
    self.random_fp_high = max(fp_range)
    # JSON schema validation of the test descriptor
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library is only needed when the data is
    # generated now rather than later
    self.dgl = (
        None if args.lazy_data_gen else GenerateLibrary(args.generate_lib_path)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -070064
def createSerializer(self, opName, testPath):
    """Create the test output directory and a new serializer for it.

    The directory is basePath/opName/testPath and is created if missing.
    The relative part is remembered in self.testPath and the new
    ts.TosaSerializer is stored in self.ser.
    """
    self.testPath = os.path.join(opName, testPath)
    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(outputDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070078
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    return self.ser
81
def serialize(self, testName, metaData=None):
    """Write the files of the current test into its output directory.

    Always written: "<testName>.tosa" (the serialized flatbuffer) and
    "desc.json" (the serializer's JSON descriptor, with metaData stored
    under "meta" when metaData is truthy). The descriptor is validated
    with self.descSchemaValidator before any meta data file is produced.

    When metaData has a "data_gen" entry, either a C++ file holding that
    entry is written (lazy data generation) or the data generator library
    is configured with the descriptor and asked to write numpy files.
    When metaData has a "compliance" entry a C++ file is written for it.
    """
    testDir = Path(self.basePath) / self.testPath
    flatbufferName = f"{testName}.tosa"

    def writeCppConfig(fileSuffix, titleComment, varPrefix, payload):
        # Emit payload as JSON wrapped in a C++ raw string literal
        with (testDir / f"{testName}{fileSuffix}.cpp").open("w") as out:
            out.write(TOSA_AUTOGENERATED_HEADER)
            out.write(titleComment)
            out.write(f'const char* {varPrefix}{testDir.stem} = R"(')
            json.dump(payload, out)
            out.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    with (testDir / flatbufferName).open("wb") as out:
        out.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(flatbufferName))
    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData:
            if self.args.lazy_data_gen:
                # Output datagen meta data as CPP data
                writeCppConfig(
                    "_meta_data_gen",
                    "// Test meta data for data generation setup\n\n",
                    "json_tdg_config_",
                    metaData["data_gen"],
                )
            else:
                # Generate the data
                self.dgl.set_config(desc)
                self.dgl.write_numpy_files(testDir)

        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            writeCppConfig(
                "_meta_compliance",
                "// Test meta data for compliance validation\n\n",
                "json_tvf_config_",
                metaData["compliance"],
            )

    # Write desc.json
    with (testDir / "desc.json").open("w") as out:
        json.dump(desc, out, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700130
def resetRNG(self, seed=None):
    """Replace self.rng with a new generator.

    seed: seed for the new generator; when None, self.random_seed + 1
    is used instead.
    """
    chosen_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(chosen_seed)
135
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the value range boundaries (low, high) for dtype.

    For FP32, FP16 and BF16 the configured floating point limits are
    returned as they are and high_inclusive has no effect. For the other
    supported types the high boundary is excluded from the range unless
    high_inclusive is True, in which case it is lowered by one.

    Raises Exception for a dtype that is not handled here.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return (self.random_fp_low, self.random_fp_high)

    # (dtype, (low, exclusive high)) pairs, scanned in order
    exclusive_ranges = (
        (DType.BOOL, (0, 2)),
        (DType.UINT8, (0, 256)),
        (DType.UINT16, (0, 65536)),
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, (-7, 8)),
        (DType.INT8, (-128, 128)),
        (DType.INT16, (-32768, 32768)),
        (DType.INT32, (-(1 << 31), (1 << 31))),
        # restricting too large value for SHAPE
        (DType.SHAPE, (-(1 << 31), (1 << 31))),
        (DType.INT48, (-(1 << 47), (1 << 47))),
    )
    for known_dtype, bounds in exclusive_ranges:
        if dtype == known_dtype:
            low, high = bounds
            break
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (low, high - 1)
    # Exclusive high: low <= range < high
    return (low, high)
170
def getRandTensor(self, shape, dtype):
    """Return a numpy array of the given shape filled with random values.

    Values are drawn from self.rng within the range reported by
    getDTypeRange. The numpy element type is bool for BOOL, int64 for
    INT48, float16 for FP16, float32 for FP32 and BF16 (BF16 values are
    passed through gtu.vect_f32_to_bf16 first) and int32 for the rest.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    samples = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(samples)

    samples_f32 = np.float32(samples)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(samples_f32))
    return samples_f32
Eric Kunzee5e26762020-10-13 16:11:07 -0700193
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair and return them in order.

    shape_list and dtype_list must have the same length. Random data
    from getRandTensor is attached to each placeholder unless lazy data
    generation is enabled, in which case None is passed as the data.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
206
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair and return them in order.

    shape_list and dtype_list must have the same length. Random data
    from getRandTensor is attached to each constant unless lazy data
    generation is enabled, in which case None is passed as the data.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addConst(
            shape,
            dtype,
            None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype),
        )
        for shape, dtype in zip(shape_list, dtype_list)
    ]
219
def makeShape(self, rank):
    """Return a shape as an int32 numpy array.

    When a target shape has been set it is returned (converted to int32)
    and rank is ignored. Otherwise rank dimensions are drawn from
    self.rng, from tensor_shape_range[0] up to but not including
    tensor_shape_range[1].
    """
    if self.targetted_shape:
        return np.int32(self.targetted_shape)

    shape_range = self.args.tensor_shape_range
    dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -0700230
def setTargetShape(self, shape):
    """Set the shape makeShape must return; a falsy value restores random shapes."""
    self.targetted_shape = shape
233
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from self.rng with low <= value < high."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
236
def getRandNumberDType(self, dtype):
    """Return a single random value for dtype, drawn from self.rng.

    The value range comes from getDTypeRange. FP32 gives np.float32,
    FP16 gives np.float16, BF16 gives the result of
    gtu.vect_f32_to_bf16 on a float32 value, BOOL gives one of
    False/True, INT48 gives np.int64 and every other dtype np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        sample = self.rng.uniform(low=low, high=high)
        if dtype == DType.FP16:
            return np.float16(sample)
        sample_f32 = np.float32(sample)
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample_f32)
        return sample_f32

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    drawn = self.rng.integers(low, high, size=1)
    if dtype == DType.INT48:
        # Special size
        return np.int64(drawn)[0]
    return np.int32(drawn)[0]
254
def shapeStr(self, shape):
    """Return the dimensions of shape joined with "x", e.g. "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700263
def typeStr(self, dtype):
    """Return the string name of dtype from gtu.DTYPE_ATTRIBUTES.

    A list or tuple of dtypes (at least two) gives the names of the
    first two joined with "x". Every element is still converted, so an
    unknown dtype anywhere in the sequence raises.

    Raises Exception when a dtype is not in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(element) for element in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700277
def typeWidth(self, dtype):
    """Return the "width" entry for dtype from gtu.DTYPE_ATTRIBUTES.

    Raises Exception when dtype is not in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700284
def constrictBatchSize(self, shape):
    """Cap shape[0] at args.max_batch_size and return shape.

    shape is modified in place. Nothing is changed when max_batch_size
    is unset (falsy) or when explicit target shapes were requested.
    """
    limit = self.args.max_batch_size
    if not limit or self.args.target_shapes:
        return shape
    # Limit the batch size unless an explicit target shape set
    shape[0] = min(shape[0], limit)
    return shape
290
def makeDimension(self):
    """Return one random dimension from randInt, bounded by args.tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
295
def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
    """Build the compliance meta data dict for the expected output tensor.

    Returns None for error tests (errorName set) and for output data
    types that gtu.dtypeIsSupportedByCompliance reports as unsupported.
    Otherwise returns a dict with "mode" and "data_type", plus
    "dot_product_info" or "ulp_info" for the modes that need them.
    """
    wanted = not errorName and gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
    if not wanted:
        # No compliance for error tests or other data types currently
        return None

    # "mode" is created first, as a placeholder, so it stays the first key
    meta = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
        selected = gtu.ComplianceMode.DOT_PRODUCT
        meta["dot_product_info"] = {"s": argsDict["s"], "ks": argsDict["ks"]}
    elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
        selected = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        selected = gtu.ComplianceMode.ULP
        meta["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        selected = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.ADD, Op.MUL, Op.SUB, Op.CEIL, Op.FLOOR, Op.CAST):
        selected = gtu.ComplianceMode.ROUND
    else:
        selected = gtu.ComplianceMode.EXACT

    meta["mode"] = gtu.ComplianceMode(selected).name
    return meta
327
328 # Build Op functions
329 # Create the output tensor (calling OutputShaper as needed)
330 # Do final tweaks to attributes (if necessary for errorIf)
331 # Add Op into graph
332 # Return resulting tensor information or BuildInfo
333
class BuildInfo:
    """Pairs the result tensor of a build function with its compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        """Store both values unchanged as attributes of the same names."""
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700340
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator to the graph and return its output tensor.

    op: either a bare integer operator id or an operator dict with "op"
        and "operands" entries.
    a: input tensor.
    validator_fcns, error_name: passed to the error-if validation.
    qinfo: [input zero point, output zero point]; only used for NEGATE.
        When it is None for NEGATE, zero is used for both zero points.

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value (no
    operator is added in that case).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        if qinfo is None:
            # No quantization information was supplied: fall back to
            # zero for both zero points (the same default build_pool2d
            # uses) instead of failing on the subscripts below.
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
391
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input broadcasting operator and return its output tensor.

    The output tensor comes from OutputShaper.binaryBroadcastOp. Returns
    None, without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return output
424
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input non-broadcasting operator and return its output tensor.

    validator_fcns and error_name are accepted but not used here; no
    error-if validation is performed.
    """
    output = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    op_id = op["op"]
    input_names = [a.name, b.name]
    self.ser.addOperator(op_id, input_names, [output.name])
    return output
429
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator and return its output tensor.

    round is stored in the operator's ArithmeticRightShiftAttribute.
    Returns None, without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_names, output_names, shift_attr)
    return output
467
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply operator and return its output tensor.

    shift is stored in the operator's MulAttribute. For non floating
    point inputs the output dtype is forced to INT32; for the
    WrongOutputType error test a dtype is then picked at random from
    INT8, INT16 and INT48. Returns None, without adding the operator,
    when TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply:
    # Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        output.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        output.setDtype(self.rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_names, output_names, mul_attr)
    return output
512
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table operator and return its output tensor.

    table is stored in the operator's TableAttribute. Returns None,
    without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names, table_attr)
    return output
546
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator with inputs (cond, a, b) and return its output tensor.

    Returns None, without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [cond.name, a.name, b.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input1": cond,
        "input2": a,
        "input3": b,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return output
583
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input comparison operator and return its output tensor.

    The output tensor comes from OutputShaper.binaryComparisonOp.
    Returns None, without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name, b.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "input1": a,
        "input2": b,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": output.shape,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return output
622
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator over axis and return its output tensor.

    axis is stored in the operator's AxisAttribute. Returns None,
    without adding the operator, when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    output = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_names = [a.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": output.shape,
        "output_dtype": output.dtype,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)
    return output
657
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator and return its output tensor.

    kernel, stride, pad, the two zero points from qinfo and accum_dtype
    are stored in the operator's PoolAttribute. When qinfo is None, zero
    is used for both zero points. Returns None, without adding the
    operator, when TosaErrorValidator.evValidateErrorIfs returns a false
    value.
    """
    output = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, output.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_names = [input.name]
    output_names = [output.name]
    placeholder_count, const_count = op["operands"]
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    # The validator sees qinfo as supplied (possibly None); the zero
    # default below is only applied afterwards
    checks = {
        "op": op,
        "input_shape": input.shape,
        "input_dtype": input.dtype,
        "output_shape": output.shape,
        "output_dtype": output.dtype,
        "kernel": kernel,
        "stride": stride,
        "pad": pad,
        "qinfo": qinfo,
        "result_tensors": [output],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    input_zp, output_zp = (0, 0) if qinfo is None else (qinfo[0], qinfo[1])

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, input_zp, output_zp, accum_dtype)

    self.ser.addOperator(op["op"], input_names, output_names, pool_attr)
    return output
719
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a max pool operator by delegating to build_pool2d.

    Max pool has no accum_dtype, so DType.UNKNOWN is supplied in its place;
    every other argument is forwarded unchanged and the result of
    build_pool2d is returned as-is.
    """
    return self.build_pool2d(
        op,
        input,
        DType.UNKNOWN,
        stride,
        pad,
        kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
744
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution operator to the serializer and return its output.

    The output tensor is produced by OutputShaper.conv2dOp. For an
    ErrorIf.WrongInputType test whose ifm dtype is neither INT8 nor UINT8,
    qinfo is rebuilt from TosaQuantGen.getZeroPoint. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result.

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
816
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution operator to the serializer and return its output.

    The output tensor is produced by OutputShaper.conv3dOp. For an
    ErrorIf.WrongInputType test whose ifm dtype is neither INT8 nor UINT8,
    qinfo is rebuilt from TosaQuantGen.getZeroPoint. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result.

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
888
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a transpose convolution operator and return its output tensor.

    The output tensor is produced by OutputShaper.transposeConv2DOp from the
    requested output_shape. For an ErrorIf.WrongInputType test whose ifm
    dtype is neither INT8 nor UINT8, qinfo is rebuilt from
    TosaQuantGen.getZeroPoint. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result.

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        # The validator is given the shape of the generated tensor, not the
        # output_shape argument
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
951
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise 2D convolution operator and return its output tensor.

    The output tensor is produced by OutputShaper.depthwiseConv2dOp. For an
    ErrorIf.WrongInputType test whose ifm dtype is neither INT8 nor UINT8,
    qinfo is rebuilt from TosaQuantGen.getZeroPoint. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result.

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1022
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully connected operator to the serializer and return its output.

    The output tensor is produced by OutputShaper.fullyConnectedOp. Returns
    None when TosaErrorValidator.evValidateErrorIfs gives a falsy result.

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1071
def build_matmul(
    self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matmul operator to the serializer and return its build info.

    The accumulator dtype is read from args_dict["acc_type"] and the output
    tensor is produced by OutputShaper.matmulOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise
    returns TosaTestGen.BuildInfo holding the result tensor and the
    compliance meta data (None when gtu.dtypeIsSupportedByCompliance rejects
    the dtype of a).

    qinfo is a two element sequence of zero points; when it is left as None
    both zero points are taken as 0 for the serialized attribute.
    """
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be subscripted below; fall back to
    # zero points of 0 (applied after validation so the validator still sees
    # the caller's value)
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    # Compliance meta data is only generated for supported input dtypes
    if gtu.dtypeIsSupportedByCompliance(a.dtype):
        compliance = self.tensorComplianceMetaData(
            op, args_dict, result_tensor, error_name
        )
    else:
        compliance = None

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001122
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Add a reduction operator over axis and return its output tensor.

    The output tensor is produced by OutputShaper.reduceOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added with an axis attribute.
    """
    reduced = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [reduced.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "output_shape": reduced.shape,
        "input_dtype": a.dtype,
        "output_dtype": reduced.dtype,
        "result_tensors": [reduced],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], operand_names, result_names, axis_attr)
    return reduced
1157
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a clamp operator with random bounds and return its output tensor.

    Two random numbers of the input dtype are drawn with
    self.getRandNumberDType; the smaller becomes min_val and the larger
    max_val. For an ErrorIf.MaxSmallerMin test the pair is redrawn until the
    values differ and the roles are swapped. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result.
    """
    clamped = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    swap_bounds = error_name == ErrorIf.MaxSmallerMin
    if swap_bounds:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    smaller, larger = min(bounds), max(bounds)
    if swap_bounds:
        min_val, max_val = larger, smaller
    else:
        min_val, max_val = smaller, larger

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [clamped.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "max_val": max_val,
        "min_val": min_val,
        "input_shape": a.shape,
        "output_shape": clamped.shape,
        "input_dtype": a.dtype,
        "output_dtype": clamped.dtype,
        "result_tensors": [clamped],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    # Floating point bounds go in the last two attribute slots and integer
    # bounds in the first two; the unused pair is filled with zeros
    if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_values = (0, 0, min_val, max_val)
    else:
        clamp_values = (min_val, max_val, 0, 0)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, *clamp_values)

    self.ser.addOperator(op["op"], operand_names, result_names, clamp_attr)
    return clamped
1213
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky relu operator with a random FP32 attribute value.

    The output tensor comes from OutputShaper.unaryOp and is returned after
    the operator is added. validator_fcns is accepted but not used here.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Drawn after the output tensor is created
    random_fp32 = self.getRandNumberDType(DType.FP32)
    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(random_fp32)

    operand_names = [a.name]
    result_names = [activated.name]
    self.ser.addOperator(op["op"], operand_names, result_names, relu_attr)
    return activated
1222
1223 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a prelu operator with a single input and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp; no attribute is
    attached and validator_fcns is accepted but not used here.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    operand_names = [a.name]
    result_names = [activated.name]
    self.ser.addOperator(op["op"], operand_names, result_names)
    return activated
1229
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a sigmoid operator to the serializer and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added without an attribute.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [activated.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": activated.shape,
        "input_dtype": a.dtype,
        "output_dtype": activated.dtype,
        "result_tensors": [activated],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], operand_names, result_names)
    return activated
1260
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a tanh operator to the serializer and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added without an attribute.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [activated.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": activated.shape,
        "input_dtype": a.dtype,
        "output_dtype": activated.dtype,
        "result_tensors": [activated],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], operand_names, result_names)
    return activated
1291
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Add an erf operator to the serializer and return its output tensor.

    The output tensor comes from OutputShaper.unaryOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added without an attribute.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [activated.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": activated.shape,
        "input_dtype": a.dtype,
        "output_dtype": activated.dtype,
        "result_tensors": [activated],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], operand_names, result_names)
    return activated
1322
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a concat operator to the serializer and return its output tensor.

    The positional arguments after op are the input tensors followed by the
    axis as the final element. The output tensor is produced by
    OutputShaper.concatOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added with an axis attribute.
    """
    # The axis type is not checked for WrongInputType tests
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(a[-1], int)

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    input_tensor_names = [tensor.name for tensor in a]

    # Invalidate Input/Output list for error if checks.
    input_list = input_tensor_names
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        # Shape and dtype are reported for the first input only; the full
        # tuple of tensors is passed as inputs
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001371
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator to the serializer and return its output tensor.

    The output tensor is produced by OutputShaper.padOp. The pad attribute is
    built from the flattened padding and both pad constants before the error
    checks run. Returns None when TosaErrorValidator.evValidateErrorIfs gives
    a falsy result.
    """
    padded = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [padded.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": padded.shape,
        "input_dtype": a.dtype,
        "output_dtype": padded.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [padded],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
        "input1": a,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], operand_names, result_names, pad_attr)
    return padded
Eric Kunzee5e26762020-10-13 16:11:07 -07001420
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a dim operator for axis to the serializer and return its output.

    The output tensor is produced by OutputShaper.dimOp. Returns None when
    TosaErrorValidator.evValidateErrorIfs gives a falsy result; otherwise the
    operator is added with an axis attribute. qinfo is accepted but not used
    here.
    """
    dim_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    operand_names = [a.name]
    result_names = [dim_tensor.name]
    n_placeholders, n_consts = op["operands"]
    operand_names, result_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, operand_names, result_names
    )

    check_args = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "output_shape": dim_tensor.shape,
        "output_dtype": dim_tensor.dtype,
        "result_tensors": [dim_tensor],
        "input_list": operand_names,
        "output_list": result_names,
        "num_operands": n_placeholders + n_consts,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], operand_names, result_names, axis_attr)
    return dim_tensor
1463
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` reshaping ``a`` to ``newShape``.

    The output tensor is produced by ``OutputShaper.reshapeOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        a: Input tensor.
        newShape: Shape stored in the operator's reshape attribute.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=output.shape,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return output
1499
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``a`` with an axis attribute.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        a: Input tensor.
        axis: Axis stored in the operator's axis attribute.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=output.shape,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return output
1534
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` permuting ``a`` by ``perms``.

    The output tensor is produced by ``OutputShaper.transposeOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        a: Input tensor.
        perms: Permutation stored in the operator's transpose attribute.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is prepared before validation, as in the original ordering.
    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=output.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return output
1570
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` slicing ``a``.

    The output tensor is produced by ``OutputShaper.sliceOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        a: Input tensor.
        start: Start value stored in the operator's slice attribute.
        size: Size value stored in the operator's slice attribute.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=output.shape,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        start=start,
        size=size,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return output
1609
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` tiling ``a`` by ``multiples``.

    The output tensor is produced by ``OutputShaper.tileOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        a: Input tensor.
        multiples: Value stored in the operator's tile attribute.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=output.shape,
        input_dtype=a.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
        input1=a,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return output
1644
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` gathering from ``values``.

    A constant INT32 indices tensor of shape (N, W) is created here, with
    N = ``values.shape[0]``, W drawn via ``self.randInt`` from
    ``self.args.tensor_shape_range`` and every index drawn from [0, K) where
    K = ``values.shape[1]``, so no index exceeds the values tensor.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        values: Tensor to gather from.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    batch = values.shape[0]
    index_limit = values.shape[1]  # K
    shape_range = self.args.tensor_shape_range
    index_count = self.randInt(shape_range[0], shape_range[1])  # W

    # (N, W) indices that stay inside the K dimension of the values tensor
    indices_arr = np.int32(
        self.rng.integers(low=0, high=index_limit, size=[batch, index_count])
    )
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    output = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=output.shape,
        input_dtype=values.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return output
1691
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` scattering ``input`` into ``values_in``.

    A constant INT32 indices tensor of shape (N, W) is created here, with
    N = ``values_in.shape[0]``, W = ``input.shape[1]`` and every index in
    [0, K) where K = ``values_in.shape[1]``.

    The TOSA SCATTER definition does not permit the same output index to be
    repeated within one operation, so each batch row uses W distinct indices
    taken from a random permutation of range(K). That is only possible when
    K >= W; otherwise the indices are drawn independently (and may repeat),
    which is what this builder always did before.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        values_in: Tensor that is updated by the scatter.
        input: Tensor supplying the data to insert.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    N = values_in.shape[0]
    K = values_in.shape[1]
    W = input.shape[1]

    if K >= W:
        # Unique output positions per batch row: first W entries of a shuffle
        indices_arr = np.empty((N, W), dtype=np.int32)
        for n in range(N):
            indices_arr[n] = self.rng.permutation(K)[:W]
    else:
        # Uniqueness is impossible with fewer than W positions; keep the
        # original independent draw so the builder still produces a tensor.
        indices_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1736
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize the operator described by ``op`` resizing ``input``.

    The output tensor is produced by ``OutputShaper.resizeOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        input: Input tensor.
        mode: Mode stored in the operator's resize attribute.
        scale: Scale stored in the operator's resize attribute.
        offset: Offset stored in the operator's resize attribute.
        border: Border stored in the operator's resize attribute.
        input_dtype: Input data type passed to the shaper and the validators.
        output_dtype: Output data type passed to the shaper and the validators.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[output],
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return output
1798
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output operator for ``val`` and ``val2``.

    Each output tensor is produced by ``OutputShaper.unaryOp`` from the
    matching input. ``validator_fcns`` is not used; no ERROR_IF validation
    is run by this builder.

    Returns:
        The first result tensor only.
    """
    # NOTE(review): ``op`` is handed to addOperator as-is, whereas the other
    # builders in this file pass op["op"]; confirm what the caller supplies.
    sources = (val, val2)
    results = [
        OutputShaper.unaryOp(self.ser, self.rng, tensor, error_name)
        for tensor in sources
    ]
    self.ser.addOperator(
        op,
        [tensor.name for tensor in sources],
        [tensor.name for tensor in results],
    )
    return results[0]
1806
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001810
1811 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` converting ``val`` to ``out_dtype``.

    The output tensor is produced by ``OutputShaper.typeConversionOp``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        val: Input tensor.
        out_dtype: Requested output data type.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    output = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]

    # Invalidate the input/output name lists when the ERROR_IF case needs it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [output.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=output.shape,
        input_dtype=val.dtype,
        output_dtype=output.dtype,
        result_tensors=[output],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return output
1844
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize the operator described by ``op`` rescaling ``val`` to ``out_dtype``.

    Random zero points and per-channel (or single) scales are generated, the
    scales are converted to multiplier/shift pairs with
    ``TosaQuantGen.computeMultiplierAndShift`` and, for scale32 tests without
    an ERROR_IF case, the stored placeholder data is clamped so that
    ``value - input_zp`` lies in ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]``.

    Args:
        op: Operator description; its "op" and "operands" entries are read.
        val: Input tensor; ``val.placeholderFilename`` must be set when the
            placeholder data has to be clamped.
        out_dtype: Requested output data type.
        scale32: Selects the scale32 clipping range and is stored in the
            rescale attribute.
        double_round: Stored in the rescale attribute.
        per_channel: When true, one scale is generated per entry of the last
            dimension of ``val``; otherwise a single scale is generated.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero zero point for the ERROR_IF case
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero zero point for the ERROR_IF case
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds int32 values; widen to 64 bits before shifting so
        # the bounds cannot be evaluated in 32 bits (e.g. under NumPy 2
        # scalar promotion), which would overflow for shifts of 32 or more.
        shift = np.int64(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1999
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor holding ``cond``.

    The tensor is a rank-0 BOOL constant unless an ERROR_IF case asks for a
    wrong data type (``CondIfCondNotMatchingBool``) or a shape that is not of
    size one (``CondIfCondShapeNotSizeOne``, randomly [2] or [1, 2]).

    Returns:
        The tensor returned by the serializer's ``addConst``.
    """
    cond_type = DType.BOOL
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)

    # Must be of size 1 (rank 0) unless the shape error is requested
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2016
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a conditional operator whose two blocks each output a constant.

    Only the shape of ``then_tens`` is used from the supplied tensors;
    ``else_tens`` is ignored. Both blocks are filled with a random INT32
    constant of that shape, except that the output-list mismatch ERROR_IF
    cases give the affected block a constant of a deliberately different
    shape.

    Args:
        op: Operator description; its "op" entry is read.
        then_tens: Tensor whose shape is used for the constants and the result.
        else_tens: Not used.
        cond: Condition value stored in the condition tensor.
        validator_fcns: Validators forwarded to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            deltas = [-3, -2, 2, 3] if dim > 3 else [1, 2, 4]
            incorrect_shape[idx] = dim + self.rng.choice(deltas)
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor has the shape of the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # The attribute carries the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    # Add the operator first, then the two blocks with their constants
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    self.ser.addBasicBlock(then_block)
    then_mismatch = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    then_shape, then_data = (
        (incorrect_shape, incorrect_arr) if then_mismatch else (out_shape, then_arr)
    )
    then_const = self.ser.addConst(then_shape, DType.INT32, then_data)
    self.ser.addOutputTensor(then_const)

    self.ser.addBasicBlock(else_block)
    else_mismatch = error_name == ErrorIf.CondIfOutputListElseGraphMismatch
    else_shape, else_data = (
        (incorrect_shape, incorrect_arr) if else_mismatch else (out_shape, else_arr)
    )
    else_const = self.ser.addConst(else_shape, DType.INT32, else_data)
    self.ser.addOutputTensor(else_const)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not passed:
        return None

    return result_tens
2085
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF operator whose then/else blocks apply a binary op.

    For floating point and INT32 inputs the then-block adds a and b and the
    else-block subtracts them; for INT8/INT16 inputs the blocks use a
    logical right shift and a logical left shift respectively.

    Args:
        op: operator definition dictionary (read via op["op"]).
        a: first input tensor; also supplies the result shape and dtype.
        b: second input tensor.
        cond: condition value handed to self._get_condition_tensor.
        validator_fcns: ERROR_IF validators passed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: the ErrorIf case being generated, or None.

    Returns:
        The result tensor of the COND_IF operator, or None if
        evValidateErrorIfs reports the test as invalid.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Prepare a copy of `a` whose shape differs in every dimension
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Raised explicitly so the check is not stripped when running with -O
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # NOTE: the loop target must not be called `op`; that would rebind the
    # operator dictionary that is passed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)

        in_then = block == then_block
        in_else = block == else_block
        bad_input = (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch and in_then
        ) or (error_name == ErrorIf.CondIfInputListElseGraphMismatch and in_else)
        bad_output = (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch and in_then
        ) or (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and in_else)

        # Only the block targeted by the ERROR_IF case gets the wrong shape
        self.ser.addInputTensor(incorrect_block_input if bad_input else a)
        self.ser.addInputTensor(b)
        block_out_shape = incorrect_block_input.shape if bad_output else a.shape
        tens = self.ser.addOutput(block_out_shape, a.dtype)

        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2168
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Serialize a WHILE_LOOP test graph that accumulates `a` on each iteration.

    Three tensors are carried through the loop: an INT32 iteration counter
    initialised from iter_val, the tensor `a`, and an accumulator placeholder
    with the shape of `a` initialised to zeros. The COND block compares the
    counter against zero with GREATER; the BODY block adds `a` to the
    accumulator and subtracts one from the counter. For the ERROR_IF cases
    handled here, one of the tensor lists is given a perturbed shape, or the
    condition output is given a non-BOOL type or a non-scalar shape.

    Returns:
        The accumulator output tensor of the WHILE_LOOP operator, or None if
        TosaErrorValidator.evValidateErrorIfs reports the test as invalid.
    """
    shape_deltas = [-3, -2, 2, 3]

    def mismatched_copy(tensor):
        # Deep copy of `tensor` with every dimension nudged by a random delta
        clone = deepcopy(tensor)
        for dim in range(len(clone.shape)):
            clone.shape[dim] += self.rng.choice(shape_deltas)
        return clone

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = mismatched_copy(acc)
        acc_out_shape = incorrect_acc.shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = mismatched_copy(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # A scalar has no dimension to perturb, so give it one instead
            incorrect_iter.shape.append(self.rng.choice(shape_deltas))

        incorrect_acc = mismatched_copy(acc)

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL

    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        cond_shape = [3] if choice == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = incorrect_iter, incorrect_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not is_valid:
        return None

    return acc_out
2289
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Add an FFT2D operator taking val1 and val2 as its two inputs.

    The output tensors come from OutputShaper.fft2dOp. The input/output name
    lists may be altered by TosaErrorIfArgGen.eiInvalidateInputOutputList
    for the given error_name before the ERROR_IF validators are run.

    Returns:
        The result tensors produced by OutputShaper.fft2dOp, or None if
        TosaErrorValidator.evValidateErrorIfs reports the test as invalid
        (in which case no operator is added).
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    input_names = [val1.name, val2.name]
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts

    # Gather name, shape and dtype of every result tensor
    output_names = []
    output_shapes = []
    output_dtypes = []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validator_args = dict(
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], input_names, output_names, attr)
    return results
2331
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add an RFFT2D operator taking val as its single input.

    The output tensors come from OutputShaper.rfft2dOp. The input/output name
    lists may be altered by TosaErrorIfArgGen.eiInvalidateInputOutputList
    for the given error_name before the ERROR_IF validators are run.

    Returns:
        The result tensors produced by OutputShaper.rfft2dOp, or None if
        TosaErrorValidator.evValidateErrorIfs reports the test as invalid
        (in which case no operator is added).
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    input_names = [val.name]
    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts

    # Gather name, shape and dtype of every result tensor
    output_names = []
    output_shapes = []
    output_dtypes = []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    validator_args = dict(
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2365
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Args:
        op: operator definition; op["rank"] is an inclusive (min, max) pair
            and op["types"] lists the dtypes the operator supports.
        shapeFilter: requested shapes; an empty or None value means [None].
        rankFilter: requested ranks, or None for no rank restriction.
        dtypeFilter: requested dtypes, or None for all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, called as validator(check=False, op=op)
            for negative tests to obtain its "param_reqs".

    Returns:
        A dict with keys "shapeFilter", "rankFilter" and "dtypeFilter", or
        None for a negative test without a validator (and for any other
        testType).
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Ranks: requested values clipped to the operator's range
    rmin, rmax = op["rank"]
    op_ranks = range(rmin, rmax + 1)
    if rankFilter is not None:
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Nothing requested: use whichever of the operator range and the
        # default range covers fewer ranks
        if len(op_ranks) <= len(default_ranks):
            cleanRankFilter = op_ranks
        else:
            cleanRankFilter = default_ranks
    else:
        cleanRankFilter = op_ranks

    # Dtypes: operator dtypes restricted to those requested; a list-valued
    # dtype entry is matched on its first element
    op_dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = op_dtypes
    else:
        cleanDtypeFilter = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    # Negative test: parameters required by the validator take precedence
    param_reqs = validator(check=False, op=op)["param_reqs"]

    required_rank = param_reqs["rank"]
    required_dtype = param_reqs["dtype"]
    required_shape = param_reqs["shape"]

    if required_shape is None:
        # Reduce number of shapes to keep test numbers small
        required_shape = shapeFilter[:2]

    return {
        "shapeFilter": required_shape,
        "rankFilter": cleanRankFilter if required_rank is None else required_rank,
        "dtypeFilter": cleanDtypeFilter if required_dtype is None else required_dtype,
    }
2446
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to build for one operator.

    Re-seeds self.rng from self.random_seed, then walks every combination of
    rank, dtype and shape permitted by self.create_filter_lists, asking the
    operator's tensor generator for shapes and its argument generator (if
    any) for argument sets.

    Args:
        opName: key of the operator in self.TOSA_OP_LIST.
        shapeFilter: list of shapes to test; None (the default) is treated
            as [None].
        rankFilter: ranks to test, or None.
        dtypeFilter: dtypes to test, or None.
        testType: "positive" or "negative" (ERROR_IF tests).

    Returns:
        A list of tuples
        (opName, testStr, dtype, error_name, shapeList, args).
        Empty if create_filter_lists returns None.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
    """
    # Avoid a mutable default argument; None stands for the old [None]
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list is made of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                # Every validator is called, even after one has matched
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList

    return testList
2553
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for opName and serialize it if it is valid.

    Creates a serializer, generates the operand tensors with the operator's
    tensor-values generator, calls its build function, and on a truthy
    result calls self.serialize("test", ...) with any data-generation and
    compliance meta data collected on the way. A falsy build result is
    reported on stdout and nothing is serialized.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
        TypeError: re-raised from the build function after printing the
            build function, tensors and arguments.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    num_placeholders, num_consts = op["operands"]
    num_operands = num_placeholders + num_consts

    # CONCAT takes its operand count from the shape list instead
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeats = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeats

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Optional quantization info generator
    qgen = op.get("qgen")
    qinfo = None
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands.
    # A single dict in testArgs selects the new argsDict interface.
    uses_args_dict = len(testArgs) == 1 and isinstance(testArgs[0], dict)
    if uses_args_dict:
        argsDict = testArgs[0]
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(
            self, opName, dtypeList, shapeList, argsDict, error_name
        )
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
    else:
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    # Without validators qinfo is passed positionally; with validators
    # everything extra is passed by keyword.
    if error_if_validators is None:
        extra_args = () if qinfo is None else (qinfo,)
        extra_kwargs = {}
    else:
        extra_args = ()
        extra_kwargs = {
            "validator_fcns": error_if_validators,
            "error_name": error_name,
        }
        if qinfo is not None:
            extra_kwargs["qinfo"] = qinfo

    try:
        result = build_fcn(self, op, *tens, *testArgs, *extra_args, **extra_kwargs)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002667
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries into one op per kernel size.

    For every kernel size a shallow copy of the matching template entry in
    self.TOSA_OP_LIST is stored under a name that includes the kernel
    dimensions, with "filter" set to the kernel and "template" set to False.
    Afterwards every entry whose "template" field is truthy is deleted.
    Does nothing if "conv2d_TEMPLATE" is no longer present.
    """
    op_list = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate the templates with
    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, bigK], [bigK, 2]]
        kernels_3d = [[1, bigK, 1], [2, 2, bigK]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # (template key, test name format) pairs, in creation order
    templates_2d = (
        ("conv2d_TEMPLATE", "conv2d_{}x{}"),
        ("depthwise_conv2d_TEMPLATE", "depthwise_conv2d_{}x{}"),
        ("transpose_conv2d_TEMPLATE", "transpose_conv2d_{}x{}"),
    )
    templates_3d = (("conv3d_TEMPLATE", "conv3d_{}x{}x{}"),)

    for kernels, templates in ((kernels_2d, templates_2d), (kernels_3d, templates_3d)):
        for kernel in kernels:
            for template_key, name_format in templates:
                entry = op_list[template_key].copy()
                entry["filter"] = kernel
                entry["template"] = False
                op_list[name_format.format(*kernel)] = entry

    # Delete any templates after having created any dynamic ops.
    # Collect first, delete second: keys must not be removed from a
    # dictionary while it is being iterated.
    template_keys = [
        name
        for name, entry in op_list.items()
        if "template" in entry and entry["template"]
    ]
    for name in template_keys:
        del op_list[name]
2723
def initOpListDefaults(self):
    """Check every op in TOSA_OP_LIST for required fields and add defaults.

    Each entry must provide "operands" (unpackable into two values),
    "build_fcn" (unpackable into four values), "types" and "op". An entry
    without a "rank" field is given self.DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the op and the first required field found missing
            or malformed.
    """
    # Fields that only need to be present: (key, error message format)
    presence_checks = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for op_name in self.TOSA_OP_LIST:
        entry = self.TOSA_OP_LIST[op_name]

        # Required fields with a fixed number of elements
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        for key, message in presence_checks:
            try:
                _ = entry[key]
            except KeyError:
                raise Exception(message.format(op_name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002765
# Tensor operator list
#  'op': Op enum value of the operator
#  'operands': tuple of (placeholder, const) operands
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function or None)
#  'types': array of datatypes to be tested
2772 # 'types': array of datatypes to be tested
# Groups of data types referenced by the "types" field of the operator list.
# The order of the entries within each group is kept as originally declared.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Excludes INT4
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]

# Integer types followed by the floating-point types - excludes INT4
TYPE_INT_FP = [*TYPE_INT, DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]

# floating-types and INT32
TYPE_FI32 = [*TYPE_FP, DType.INT32]

# Floating-point types, then integer types, then BOOL
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32, *TYPE_INT, DType.BOOL]

TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to any operator entry that has no "rank" field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002817
2818 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002819 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002820 "argmax": {
2821 "op": Op.ARGMAX,
2822 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002823 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002824 "build_fcn": (
2825 build_argmax,
2826 TosaTensorGen.tgBasic,
2827 TosaTensorValuesGen.tvgDefault,
2828 TosaArgGen.agAxis,
2829 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002830 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002831 "error_if_validators": (
2832 TosaErrorValidator.evAxisSmallerZero,
2833 TosaErrorValidator.evAxisLargerRank,
2834 TosaErrorValidator.evArgmaxOutputRankMismatch,
2835 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2836 TosaErrorValidator.evWrongRank,
2837 TosaErrorValidator.evWrongInputType,
2838 TosaErrorValidator.evWrongOutputType,
2839 TosaErrorValidator.evWrongInputList,
2840 TosaErrorValidator.evWrongOutputList,
2841 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002842 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002843 "avg_pool2d": {
2844 "op": Op.AVG_POOL2D,
2845 "operands": (1, 0),
2846 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002847 "build_fcn": (
2848 build_pool2d,
2849 TosaTensorGen.tgNHWC,
2850 TosaTensorValuesGen.tvgDefault,
2851 TosaArgGen.agPooling,
2852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002853 "qgen": TosaQuantGen.qgUnary,
2854 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002855 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002856 "error_if_validators": (
2857 TosaErrorValidator.evKernelSmallerOne,
2858 TosaErrorValidator.evStrideSmallerOne,
2859 TosaErrorValidator.evPadSmallerZero,
2860 TosaErrorValidator.evWrongRank,
2861 TosaErrorValidator.evWrongInputType,
2862 TosaErrorValidator.evWrongOutputType,
2863 TosaErrorValidator.evWrongInputList,
2864 TosaErrorValidator.evWrongOutputList,
2865 TosaErrorValidator.evInputZeroPointNotZero,
2866 TosaErrorValidator.evOutputZeroPointNotZero,
2867 TosaErrorValidator.evPadLargerEqualKernel,
2868 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002869 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002870 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002871 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002872 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002873 "conv2d_TEMPLATE": {
2874 "op": Op.CONV2D,
2875 "operands": (1, 2),
2876 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002877 "build_fcn": (
2878 build_conv2d,
2879 TosaTensorGen.tgConv2D,
2880 TosaTensorValuesGen.tvgDefault,
2881 TosaArgGen.agConv,
2882 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002883 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002884 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002885 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2886 "error_if_validators": (
2887 TosaErrorValidator.evWrongInputType,
2888 TosaErrorValidator.evWrongOutputType,
2889 TosaErrorValidator.evWrongInputList,
2890 TosaErrorValidator.evWrongOutputList,
2891 TosaErrorValidator.evInputZeroPointNotZero,
2892 TosaErrorValidator.evWeightZeroPointNotZero,
2893 TosaErrorValidator.evPadSmallerZero,
2894 TosaErrorValidator.evStrideSmallerOne,
2895 TosaErrorValidator.evDilationSmallerOne,
2896 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002897 TosaErrorValidator.evConvOutputShapeMismatch,
2898 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002899 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002900 "template": True,
2901 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002902 # Templated operator. Filled in by createDynamicOpLists
2903 "conv3d_TEMPLATE": {
2904 "op": Op.CONV3D,
2905 "operands": (1, 2),
2906 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002907 "build_fcn": (
2908 build_conv3d,
2909 TosaTensorGen.tgConv3D,
2910 TosaTensorValuesGen.tvgDefault,
2911 TosaArgGen.agConv,
2912 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002913 "qgen": TosaQuantGen.qgConv,
2914 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002915 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2916 "error_if_validators": (
2917 TosaErrorValidator.evWrongInputType,
2918 TosaErrorValidator.evWrongOutputType,
2919 TosaErrorValidator.evWrongInputList,
2920 TosaErrorValidator.evWrongOutputList,
2921 TosaErrorValidator.evInputZeroPointNotZero,
2922 TosaErrorValidator.evWeightZeroPointNotZero,
2923 TosaErrorValidator.evPadSmallerZero,
2924 TosaErrorValidator.evStrideSmallerOne,
2925 TosaErrorValidator.evDilationSmallerOne,
2926 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002927 TosaErrorValidator.evConvOutputShapeMismatch,
2928 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002929 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002930 "template": True,
2931 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002932 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002933 "depthwise_conv2d_TEMPLATE": {
2934 "op": Op.DEPTHWISE_CONV2D,
2935 "operands": (1, 2),
2936 "filter": [1, 1],
2937 "rank": (4, 4),
2938 "build_fcn": (
2939 build_depthwise_conv2d,
2940 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002941 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002942 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002943 ),
2944 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002945 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002946 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2947 "error_if_validators": (
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 TosaErrorValidator.evInputZeroPointNotZero,
2953 TosaErrorValidator.evWeightZeroPointNotZero,
2954 TosaErrorValidator.evPadSmallerZero,
2955 TosaErrorValidator.evStrideSmallerOne,
2956 TosaErrorValidator.evDilationSmallerOne,
2957 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002958 TosaErrorValidator.evConvOutputShapeMismatch,
2959 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002960 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002961 "template": True,
2962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 "fully_connected": {
2964 "op": Op.FULLY_CONNECTED,
2965 "operands": (1, 2),
2966 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002967 "build_fcn": (
2968 build_fully_connected,
2969 TosaTensorGen.tgFullyConnected,
2970 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002971 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002973 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002974 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002975 "error_if_validators": (
2976 TosaErrorValidator.evInputZeroPointNotZero,
2977 TosaErrorValidator.evWeightZeroPointNotZero,
2978 TosaErrorValidator.evWrongRank,
2979 TosaErrorValidator.evWrongInputType,
2980 TosaErrorValidator.evWrongOutputType,
2981 TosaErrorValidator.evWrongInputList,
2982 TosaErrorValidator.evWrongOutputList,
2983 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002984 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002985 "matmul": {
2986 "op": Op.MATMUL,
2987 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002988 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002989 "build_fcn": (
2990 build_matmul,
2991 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002992 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01002993 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002994 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002995 "qgen": TosaQuantGen.qgMatmul,
2996 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002997 "error_if_validators": (
2998 TosaErrorValidator.evInputZeroPointNotZero,
2999 TosaErrorValidator.evWrongRank,
3000 TosaErrorValidator.evWrongInputType,
3001 TosaErrorValidator.evWrongOutputType,
3002 TosaErrorValidator.evWrongInputList,
3003 TosaErrorValidator.evWrongOutputList,
3004 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003005 "data_gen": {
3006 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3007 "int": (gtu.DataGenType.PSEUDO_RANDOM,),
3008 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 "max_pool2d": {
3011 "op": Op.MAX_POOL2D,
3012 "operands": (1, 0),
3013 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003015 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003016 TosaTensorGen.tgNHWC,
3017 TosaTensorValuesGen.tvgDefault,
3018 TosaArgGen.agPooling,
3019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003021 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003022 "error_if_validators": (
3023 TosaErrorValidator.evKernelSmallerOne,
3024 TosaErrorValidator.evStrideSmallerOne,
3025 TosaErrorValidator.evPadSmallerZero,
3026 TosaErrorValidator.evWrongRank,
3027 TosaErrorValidator.evWrongInputType,
3028 TosaErrorValidator.evWrongOutputType,
3029 TosaErrorValidator.evWrongInputList,
3030 TosaErrorValidator.evWrongOutputList,
3031 TosaErrorValidator.evPadLargerEqualKernel,
3032 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003033 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003034 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003035 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003036 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003037 "transpose_conv2d_TEMPLATE": {
3038 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003039 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003040 "rank": (4, 4),
3041 "build_fcn": (
3042 build_transpose_conv2d,
3043 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003045 TosaArgGen.agTransposeConv2D,
3046 ),
3047 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003048 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003049 "invalid_test_validators": (
3050 TosaInvalidValidator.ivHeightWidthInvalid,
3051 TosaInvalidValidator.ivNonPositiveOutputShape,
3052 ),
3053 "error_if_validators": (
3054 TosaErrorValidator.evWrongInputType,
3055 TosaErrorValidator.evWrongOutputType,
3056 TosaErrorValidator.evWrongInputList,
3057 TosaErrorValidator.evWrongOutputList,
3058 TosaErrorValidator.evInputZeroPointNotZero,
3059 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003060 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003061 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003062 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003063 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003064 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003065 "template": True,
3066 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003067 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003068 "clamp": {
3069 "op": Op.CLAMP,
3070 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003071 "build_fcn": (
3072 build_clamp,
3073 TosaTensorGen.tgBasic,
3074 TosaTensorValuesGen.tvgDefault,
3075 None,
3076 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003077 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003078 "error_if_validators": (
3079 TosaErrorValidator.evMaxSmallerMin,
3080 TosaErrorValidator.evWrongInputType,
3081 TosaErrorValidator.evWrongOutputType,
3082 TosaErrorValidator.evWrongInputList,
3083 TosaErrorValidator.evWrongOutputList,
3084 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003085 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003086 "sigmoid": {
3087 "op": Op.SIGMOID,
3088 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003089 "build_fcn": (
3090 build_sigmoid,
3091 TosaTensorGen.tgBasic,
3092 TosaTensorValuesGen.tvgDefault,
3093 None,
3094 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003095 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003096 "error_if_validators": (
3097 TosaErrorValidator.evWrongInputType,
3098 TosaErrorValidator.evWrongOutputType,
3099 TosaErrorValidator.evWrongInputList,
3100 TosaErrorValidator.evWrongOutputList,
3101 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003102 },
3103 "tanh": {
3104 "op": Op.TANH,
3105 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003106 "build_fcn": (
3107 build_tanh,
3108 TosaTensorGen.tgBasic,
3109 TosaTensorValuesGen.tvgDefault,
3110 None,
3111 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003112 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003113 "error_if_validators": (
3114 TosaErrorValidator.evWrongInputType,
3115 TosaErrorValidator.evWrongOutputType,
3116 TosaErrorValidator.evWrongInputList,
3117 TosaErrorValidator.evWrongOutputList,
3118 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003119 },
Won Jeon78155c62023-06-10 00:20:04 +00003120 "erf": {
3121 "op": Op.ERF,
3122 "operands": (1, 0),
3123 "build_fcn": (
3124 build_erf,
3125 TosaTensorGen.tgBasic,
3126 TosaTensorValuesGen.tvgDefault,
3127 None,
3128 ),
3129 "types": TYPE_FP,
3130 "error_if_validators": (
3131 TosaErrorValidator.evWrongInputType,
3132 TosaErrorValidator.evWrongOutputType,
3133 TosaErrorValidator.evWrongInputList,
3134 TosaErrorValidator.evWrongOutputList,
3135 ),
3136 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 # Elementwise Binary Operators
3138 "add": {
3139 "op": Op.ADD,
3140 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003141 "build_fcn": (
3142 build_binary_broadcast,
3143 TosaTensorGen.tgBroadcastFuzz,
3144 TosaTensorValuesGen.tvgAddSub,
3145 None,
3146 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003147 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003148 "error_if_validators": (
3149 TosaErrorValidator.evRankMismatch,
3150 TosaErrorValidator.evWrongInputType,
3151 TosaErrorValidator.evWrongOutputType,
3152 TosaErrorValidator.evWrongInputList,
3153 TosaErrorValidator.evWrongOutputList,
3154 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003155 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003156 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 "arithmetic_right_shift": {
3159 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3160 "operands": (2, 0),
3161 "build_fcn": (
3162 build_arithmetic_right_shift,
3163 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003164 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 TosaArgGen.agArithmeticRightShift,
3166 ),
3167 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 "error_if_validators": (
3169 TosaErrorValidator.evRankMismatch,
3170 TosaErrorValidator.evWrongInputType,
3171 TosaErrorValidator.evWrongOutputType,
3172 TosaErrorValidator.evWrongInputList,
3173 TosaErrorValidator.evWrongOutputList,
3174 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003175 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003176 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003178 "bitwise_and": {
3179 "op": Op.BITWISE_AND,
3180 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003181 "build_fcn": (
3182 build_binary_broadcast,
3183 TosaTensorGen.tgBroadcastFuzz,
3184 TosaTensorValuesGen.tvgDefault,
3185 None,
3186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003187 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003188 "error_if_validators": (
3189 TosaErrorValidator.evRankMismatch,
3190 TosaErrorValidator.evWrongInputType,
3191 TosaErrorValidator.evWrongOutputType,
3192 TosaErrorValidator.evWrongInputList,
3193 TosaErrorValidator.evWrongOutputList,
3194 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003195 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003196 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003197 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003198 "bitwise_or": {
3199 "op": Op.BITWISE_OR,
3200 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003201 "build_fcn": (
3202 build_binary_broadcast,
3203 TosaTensorGen.tgBroadcastFuzz,
3204 TosaTensorValuesGen.tvgDefault,
3205 None,
3206 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003207 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003208 "error_if_validators": (
3209 TosaErrorValidator.evRankMismatch,
3210 TosaErrorValidator.evWrongInputType,
3211 TosaErrorValidator.evWrongOutputType,
3212 TosaErrorValidator.evWrongInputList,
3213 TosaErrorValidator.evWrongOutputList,
3214 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003215 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003216 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003217 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003218 "bitwise_xor": {
3219 "op": Op.BITWISE_XOR,
3220 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003221 "build_fcn": (
3222 build_binary_broadcast,
3223 TosaTensorGen.tgBroadcastFuzz,
3224 TosaTensorValuesGen.tvgDefault,
3225 None,
3226 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003228 "error_if_validators": (
3229 TosaErrorValidator.evRankMismatch,
3230 TosaErrorValidator.evWrongInputType,
3231 TosaErrorValidator.evWrongOutputType,
3232 TosaErrorValidator.evWrongInputList,
3233 TosaErrorValidator.evWrongOutputList,
3234 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003235 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003237 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003238 "intdiv": {
3239 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003240 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 "build_fcn": (
3242 build_binary_broadcast,
3243 TosaTensorGen.tgBroadcastFuzz,
3244 TosaTensorValuesGen.tvgIntDiv,
3245 None,
3246 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003247 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003248 "error_if_validators": (
3249 TosaErrorValidator.evRankMismatch,
3250 TosaErrorValidator.evWrongInputType,
3251 TosaErrorValidator.evWrongOutputType,
3252 TosaErrorValidator.evWrongInputList,
3253 TosaErrorValidator.evWrongOutputList,
3254 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003255 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003256 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003257 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 "logical_and": {
3259 "op": Op.LOGICAL_AND,
3260 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003261 "build_fcn": (
3262 build_binary_broadcast,
3263 TosaTensorGen.tgBroadcastFuzz,
3264 TosaTensorValuesGen.tvgDefault,
3265 None,
3266 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003267 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003268 "error_if_validators": (
3269 TosaErrorValidator.evRankMismatch,
3270 TosaErrorValidator.evWrongInputType,
3271 TosaErrorValidator.evWrongOutputType,
3272 TosaErrorValidator.evWrongInputList,
3273 TosaErrorValidator.evWrongOutputList,
3274 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003275 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003276 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003278 "logical_left_shift": {
3279 "op": Op.LOGICAL_LEFT_SHIFT,
3280 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 "build_fcn": (
3282 build_binary_broadcast,
3283 TosaTensorGen.tgBroadcastFuzz,
3284 TosaTensorValuesGen.tvgLogicalShift,
3285 None,
3286 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003288 "error_if_validators": (
3289 TosaErrorValidator.evRankMismatch,
3290 TosaErrorValidator.evWrongInputType,
3291 TosaErrorValidator.evWrongOutputType,
3292 TosaErrorValidator.evWrongInputList,
3293 TosaErrorValidator.evWrongOutputList,
3294 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003295 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003296 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 "logical_right_shift": {
3299 "op": Op.LOGICAL_RIGHT_SHIFT,
3300 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003301 "build_fcn": (
3302 build_binary_broadcast,
3303 TosaTensorGen.tgBroadcastFuzz,
3304 TosaTensorValuesGen.tvgLogicalShift,
3305 None,
3306 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003308 "error_if_validators": (
3309 TosaErrorValidator.evRankMismatch,
3310 TosaErrorValidator.evWrongInputType,
3311 TosaErrorValidator.evWrongOutputType,
3312 TosaErrorValidator.evWrongInputList,
3313 TosaErrorValidator.evWrongOutputList,
3314 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003315 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003316 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "logical_or": {
3319 "op": Op.LOGICAL_OR,
3320 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003321 "build_fcn": (
3322 build_binary_broadcast,
3323 TosaTensorGen.tgBroadcastFuzz,
3324 TosaTensorValuesGen.tvgDefault,
3325 None,
3326 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003327 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003328 "error_if_validators": (
3329 TosaErrorValidator.evRankMismatch,
3330 TosaErrorValidator.evWrongInputType,
3331 TosaErrorValidator.evWrongOutputType,
3332 TosaErrorValidator.evWrongInputList,
3333 TosaErrorValidator.evWrongOutputList,
3334 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003335 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "logical_xor": {
3339 "op": Op.LOGICAL_XOR,
3340 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 "build_fcn": (
3342 build_binary_broadcast,
3343 TosaTensorGen.tgBroadcastFuzz,
3344 TosaTensorValuesGen.tvgDefault,
3345 None,
3346 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003348 "error_if_validators": (
3349 TosaErrorValidator.evRankMismatch,
3350 TosaErrorValidator.evWrongInputType,
3351 TosaErrorValidator.evWrongOutputType,
3352 TosaErrorValidator.evWrongInputList,
3353 TosaErrorValidator.evWrongOutputList,
3354 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003355 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003356 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003357 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003358 "maximum": {
3359 "op": Op.MAXIMUM,
3360 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003361 "build_fcn": (
3362 build_binary_broadcast,
3363 TosaTensorGen.tgBroadcastFuzz,
3364 TosaTensorValuesGen.tvgDefault,
3365 None,
3366 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003367 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003368 "error_if_validators": (
3369 TosaErrorValidator.evRankMismatch,
3370 TosaErrorValidator.evWrongInputType,
3371 TosaErrorValidator.evWrongOutputType,
3372 TosaErrorValidator.evWrongInputList,
3373 TosaErrorValidator.evWrongOutputList,
3374 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003375 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003378 "minimum": {
3379 "op": Op.MINIMUM,
3380 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 "build_fcn": (
3382 build_binary_broadcast,
3383 TosaTensorGen.tgBroadcastFuzz,
3384 TosaTensorValuesGen.tvgDefault,
3385 None,
3386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003388 "error_if_validators": (
3389 TosaErrorValidator.evRankMismatch,
3390 TosaErrorValidator.evWrongInputType,
3391 TosaErrorValidator.evWrongOutputType,
3392 TosaErrorValidator.evWrongInputList,
3393 TosaErrorValidator.evWrongOutputList,
3394 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003395 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003396 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "mul": {
3399 "op": Op.MUL,
3400 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401 "build_fcn": (
3402 build_mul,
3403 TosaTensorGen.tgBroadcastFuzz,
3404 TosaTensorValuesGen.tvgMul,
3405 TosaArgGen.agMul,
3406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003407 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003408 "error_if_validators": (
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 TosaErrorValidator.evRankMismatch,
3414 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003415 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "pow": {
3419 "op": Op.POW,
3420 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003421 "build_fcn": (
3422 build_binary_broadcast,
3423 TosaTensorGen.tgBroadcastFuzz,
3424 TosaTensorValuesGen.tvgDefault,
3425 None,
3426 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003427 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003428 "error_if_validators": (
3429 TosaErrorValidator.evRankMismatch,
3430 TosaErrorValidator.evWrongInputType,
3431 TosaErrorValidator.evWrongOutputType,
3432 TosaErrorValidator.evWrongInputList,
3433 TosaErrorValidator.evWrongOutputList,
3434 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003435 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003437 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 "sub": {
3439 "op": Op.SUB,
3440 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 "build_fcn": (
3442 build_binary_broadcast,
3443 TosaTensorGen.tgBroadcastFuzz,
3444 TosaTensorValuesGen.tvgAddSub,
3445 None,
3446 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003447 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003448 "error_if_validators": (
3449 TosaErrorValidator.evRankMismatch,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003455 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003456 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 "table": {
3459 "op": Op.TABLE,
3460 # Use the automatic generation functions to create the input array
3461 # but create the table tensor in the build function, as it may be
3462 # a different type from the input
3463 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003464 "build_fcn": (
3465 build_table,
3466 TosaTensorGen.tgBasic,
3467 TosaTensorValuesGen.tvgDefault,
3468 TosaArgGen.agTable,
3469 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003470 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 "error_if_validators": (
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongInputList,
3475 TosaErrorValidator.evWrongOutputList,
3476 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003477 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003478 # Elementwise Unary operators
3479 "abs": {
3480 "op": Op.ABS,
3481 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003482 "build_fcn": (
3483 build_unary,
3484 TosaTensorGen.tgBasic,
3485 TosaTensorValuesGen.tvgDefault,
3486 None,
3487 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003488 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 "error_if_validators": (
3490 TosaErrorValidator.evWrongInputType,
3491 TosaErrorValidator.evWrongOutputType,
3492 TosaErrorValidator.evWrongInputList,
3493 TosaErrorValidator.evWrongOutputList,
3494 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "bitwise_not": {
3497 "op": Op.BITWISE_NOT,
3498 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003499 "build_fcn": (
3500 build_unary,
3501 TosaTensorGen.tgBasic,
3502 TosaTensorValuesGen.tvgDefault,
3503 None,
3504 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003506 "error_if_validators": (
3507 TosaErrorValidator.evWrongInputType,
3508 TosaErrorValidator.evWrongOutputType,
3509 TosaErrorValidator.evWrongInputList,
3510 TosaErrorValidator.evWrongOutputList,
3511 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "ceil": {
3514 "op": Op.CEIL,
3515 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003516 "build_fcn": (
3517 build_unary,
3518 TosaTensorGen.tgBasic,
3519 TosaTensorValuesGen.tvgDefault,
3520 None,
3521 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003522 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003523 "error_if_validators": (
3524 TosaErrorValidator.evWrongInputType,
3525 TosaErrorValidator.evWrongOutputType,
3526 TosaErrorValidator.evWrongInputList,
3527 TosaErrorValidator.evWrongOutputList,
3528 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003530 "clz": {
3531 "op": Op.CLZ,
3532 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003533 "build_fcn": (
3534 build_unary,
3535 TosaTensorGen.tgBasic,
3536 TosaTensorValuesGen.tvgDefault,
3537 None,
3538 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003539 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003540 "error_if_validators": (
3541 TosaErrorValidator.evWrongInputType,
3542 TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "exp": {
3548 "op": Op.EXP,
3549 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003550 "build_fcn": (
3551 build_unary,
3552 TosaTensorGen.tgBasic,
3553 TosaTensorValuesGen.tvgDefault,
3554 None,
3555 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003557 "error_if_validators": (
3558 TosaErrorValidator.evWrongInputType,
3559 TosaErrorValidator.evWrongOutputType,
3560 TosaErrorValidator.evWrongInputList,
3561 TosaErrorValidator.evWrongOutputList,
3562 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "floor": {
3565 "op": Op.FLOOR,
3566 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 "build_fcn": (
3568 build_unary,
3569 TosaTensorGen.tgBasic,
3570 TosaTensorValuesGen.tvgDefault,
3571 None,
3572 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 "error_if_validators": (
3575 TosaErrorValidator.evWrongInputType,
3576 TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList,
3579 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003580 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "log": {
3582 "op": Op.LOG,
3583 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003584 "build_fcn": (
3585 build_unary,
3586 TosaTensorGen.tgBasic,
3587 TosaTensorValuesGen.tvgDefault,
3588 None,
3589 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003591 "error_if_validators": (
3592 TosaErrorValidator.evWrongInputType,
3593 TosaErrorValidator.evWrongOutputType,
3594 TosaErrorValidator.evWrongInputList,
3595 TosaErrorValidator.evWrongOutputList,
3596 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003598 "logical_not": {
3599 "op": Op.LOGICAL_NOT,
3600 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003601 "build_fcn": (
3602 build_unary,
3603 TosaTensorGen.tgBasic,
3604 TosaTensorValuesGen.tvgDefault,
3605 None,
3606 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003608 "error_if_validators": (
3609 TosaErrorValidator.evWrongInputType,
3610 TosaErrorValidator.evWrongOutputType,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "negate": {
3616 "op": Op.NEGATE,
3617 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 "build_fcn": (
3619 build_unary,
3620 TosaTensorGen.tgBasic,
3621 TosaTensorValuesGen.tvgNegate,
3622 None,
3623 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003624 "qgen": TosaQuantGen.qgUnary,
3625 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 "error_if_validators": (
3627 TosaErrorValidator.evInputZeroPointNotZero,
3628 TosaErrorValidator.evOutputZeroPointNotZero,
3629 TosaErrorValidator.evWrongInputType,
3630 TosaErrorValidator.evWrongOutputType,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 "reciprocal": {
3636 "op": Op.RECIPROCAL,
3637 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003638 "build_fcn": (
3639 build_unary,
3640 TosaTensorGen.tgBasic,
3641 TosaTensorValuesGen.tvgDefault,
3642 None,
3643 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003644 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003645 "error_if_validators": (
3646 TosaErrorValidator.evWrongInputType,
3647 TosaErrorValidator.evWrongOutputType,
3648 TosaErrorValidator.evWrongInputList,
3649 TosaErrorValidator.evWrongOutputList,
3650 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 "rsqrt": {
3653 "op": Op.RSQRT,
3654 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 "build_fcn": (
3656 build_unary,
3657 TosaTensorGen.tgBasic,
3658 TosaTensorValuesGen.tvgDefault,
3659 None,
3660 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 "error_if_validators": (
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 # Elementwise Ternary operators
3670 "select": {
3671 "op": Op.SELECT,
3672 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
3674 build_select,
3675 TosaTensorGen.tgBroadcastFuzz,
3676 TosaTensorValuesGen.tvgSelect,
3677 None,
3678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evRankMismatch,
3682 TosaErrorValidator.evWrongInputType,
3683 TosaErrorValidator.evWrongOutputType,
3684 TosaErrorValidator.evWrongInputList,
3685 TosaErrorValidator.evWrongOutputList,
3686 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003687 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 # Comparison operators
3691 "equal": {
3692 "op": Op.EQUAL,
3693 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 "build_fcn": (
3695 build_comparison,
3696 TosaTensorGen.tgBroadcastFuzz,
3697 TosaTensorValuesGen.tvgEqual,
3698 None,
3699 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003700 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003701 "error_if_validators": (
3702 TosaErrorValidator.evRankMismatch,
3703 TosaErrorValidator.evWrongInputType,
3704 TosaErrorValidator.evWrongOutputType,
3705 TosaErrorValidator.evWrongInputList,
3706 TosaErrorValidator.evWrongOutputList,
3707 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003708 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 "greater_equal": {
3712 "op": Op.GREATER_EQUAL,
3713 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_comparison,
3716 TosaTensorGen.tgBroadcastFuzz,
3717 TosaTensorValuesGen.tvgDefault,
3718 None,
3719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evRankMismatch,
3723 TosaErrorValidator.evWrongInputType,
3724 TosaErrorValidator.evWrongOutputType,
3725 TosaErrorValidator.evWrongInputList,
3726 TosaErrorValidator.evWrongOutputList,
3727 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003728 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 "greater": {
3732 "op": Op.GREATER,
3733 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 "build_fcn": (
3735 build_comparison,
3736 TosaTensorGen.tgBroadcastFuzz,
3737 TosaTensorValuesGen.tvgDefault,
3738 None,
3739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003741 "error_if_validators": (
3742 TosaErrorValidator.evRankMismatch,
3743 TosaErrorValidator.evWrongInputType,
3744 TosaErrorValidator.evWrongOutputType,
3745 TosaErrorValidator.evWrongInputList,
3746 TosaErrorValidator.evWrongOutputList,
3747 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003748 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 # Reduction operators
3752 "reduce_all": {
3753 "op": Op.REDUCE_ALL,
3754 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003755 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 "build_fcn": (
3757 build_reduce,
3758 TosaTensorGen.tgBasic,
3759 TosaTensorValuesGen.tvgDefault,
3760 TosaArgGen.agAxis,
3761 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003763 "error_if_validators": (
3764 TosaErrorValidator.evAxisLargerRank,
3765 TosaErrorValidator.evAxisSmallerZero,
3766 TosaErrorValidator.evShapeOfAxisNotOne,
3767 TosaErrorValidator.evWrongInputType,
3768 TosaErrorValidator.evWrongOutputType,
3769 TosaErrorValidator.evWrongRank,
3770 TosaErrorValidator.evWrongInputList,
3771 TosaErrorValidator.evWrongOutputList,
3772 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 "reduce_any": {
3775 "op": Op.REDUCE_ANY,
3776 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003777 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 "build_fcn": (
3779 build_reduce,
3780 TosaTensorGen.tgBasic,
3781 TosaTensorValuesGen.tvgDefault,
3782 TosaArgGen.agAxis,
3783 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003784 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 "error_if_validators": (
3786 TosaErrorValidator.evAxisLargerRank,
3787 TosaErrorValidator.evAxisSmallerZero,
3788 TosaErrorValidator.evShapeOfAxisNotOne,
3789 TosaErrorValidator.evWrongInputType,
3790 TosaErrorValidator.evWrongOutputType,
3791 TosaErrorValidator.evWrongRank,
3792 TosaErrorValidator.evWrongInputList,
3793 TosaErrorValidator.evWrongOutputList,
3794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 "reduce_max": {
3797 "op": Op.REDUCE_MAX,
3798 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003799 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 "build_fcn": (
3801 build_reduce,
3802 TosaTensorGen.tgBasic,
3803 TosaTensorValuesGen.tvgDefault,
3804 TosaArgGen.agAxis,
3805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 "error_if_validators": (
3808 TosaErrorValidator.evAxisLargerRank,
3809 TosaErrorValidator.evAxisSmallerZero,
3810 TosaErrorValidator.evShapeOfAxisNotOne,
3811 TosaErrorValidator.evWrongInputType,
3812 TosaErrorValidator.evWrongOutputType,
3813 TosaErrorValidator.evWrongRank,
3814 TosaErrorValidator.evWrongInputList,
3815 TosaErrorValidator.evWrongOutputList,
3816 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003817 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003819 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003820 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003821 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 "build_fcn": (
3823 build_reduce,
3824 TosaTensorGen.tgBasic,
3825 TosaTensorValuesGen.tvgDefault,
3826 TosaArgGen.agAxis,
3827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003829 "error_if_validators": (
3830 TosaErrorValidator.evAxisLargerRank,
3831 TosaErrorValidator.evAxisSmallerZero,
3832 TosaErrorValidator.evShapeOfAxisNotOne,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongRank,
3836 TosaErrorValidator.evWrongInputList,
3837 TosaErrorValidator.evWrongOutputList,
3838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 "reduce_product": {
3841 "op": Op.REDUCE_PRODUCT,
3842 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003843 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_reduce,
3846 TosaTensorGen.tgBasic,
3847 TosaTensorValuesGen.tvgDefault,
3848 TosaArgGen.agAxis,
3849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evAxisLargerRank,
3853 TosaErrorValidator.evAxisSmallerZero,
3854 TosaErrorValidator.evShapeOfAxisNotOne,
3855 TosaErrorValidator.evWrongInputType,
3856 TosaErrorValidator.evWrongOutputType,
3857 TosaErrorValidator.evWrongRank,
3858 TosaErrorValidator.evWrongInputList,
3859 TosaErrorValidator.evWrongOutputList,
3860 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 "reduce_sum": {
3863 "op": Op.REDUCE_SUM,
3864 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003865 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003866 "build_fcn": (
3867 build_reduce,
3868 TosaTensorGen.tgBasic,
3869 TosaTensorValuesGen.tvgReduceSum,
3870 TosaArgGen.agAxis,
3871 ),
James Ward24dbc422022-10-19 12:20:31 +01003872 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 "error_if_validators": (
3874 TosaErrorValidator.evAxisLargerRank,
3875 TosaErrorValidator.evAxisSmallerZero,
3876 TosaErrorValidator.evShapeOfAxisNotOne,
3877 TosaErrorValidator.evWrongInputType,
3878 TosaErrorValidator.evWrongOutputType,
3879 TosaErrorValidator.evWrongRank,
3880 TosaErrorValidator.evWrongInputList,
3881 TosaErrorValidator.evWrongOutputList,
3882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003884 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003885 "concat": {
3886 "op": Op.CONCAT,
3887 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003888 "build_fcn": (
3889 build_concat,
3890 TosaTensorGen.tgConcat,
3891 TosaTensorValuesGen.tvgConcat,
3892 TosaArgGen.agAxis,
3893 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003894 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 "error_if_validators": (
3896 TosaErrorValidator.evAxisLargerRank,
3897 TosaErrorValidator.evAxisSmallerZero,
3898 TosaErrorValidator.evConcatInputRankMismatch,
3899 TosaErrorValidator.evConcatShapeSumMismatch,
3900 TosaErrorValidator.evConcatInputDimMismatch,
3901 TosaErrorValidator.evWrongInputType,
3902 TosaErrorValidator.evWrongOutputType,
3903 TosaErrorValidator.evWrongOutputList,
3904 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003905 },
3906 "pad": {
3907 "op": Op.PAD,
3908 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 "build_fcn": (
3910 build_pad,
3911 TosaTensorGen.tgBasic,
3912 TosaTensorValuesGen.tvgDefault,
3913 TosaArgGen.agPad,
3914 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003915 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 "error_if_validators": (
3917 TosaErrorValidator.evWrongInputType,
3918 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003919 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003920 TosaErrorValidator.evWrongOutputType,
3921 TosaErrorValidator.evWrongInputList,
3922 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003923 TosaErrorValidator.evRankMismatch,
3924 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003926 },
Won Jeona21b2e82023-08-10 10:33:01 +00003927 "dim": {
3928 "op": Op.DIM,
3929 "operands": (1, 0),
3930 "build_fcn": (
3931 build_dim,
3932 TosaTensorGen.tgBasic,
3933 TosaTensorValuesGen.tvgDefault,
3934 TosaArgGen.agAxis,
3935 ),
3936 "types": TYPE_FIB,
3937 "error_if_validators": (
3938 TosaErrorValidator.evAxisLargerRank,
3939 TosaErrorValidator.evAxisSmallerZero,
3940 TosaErrorValidator.evWrongInputType,
3941 TosaErrorValidator.evWrongInputList,
3942 TosaErrorValidator.evWrongOutputList,
3943 TosaErrorValidator.evWrongRank,
3944 ),
3945 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003946 "reshape": {
3947 "op": Op.RESHAPE,
3948 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003949 "build_fcn": (
3950 build_reshape,
3951 TosaTensorGen.tgBasic,
3952 TosaTensorValuesGen.tvgDefault,
3953 TosaArgGen.agReshape,
3954 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003955 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003956 "error_if_validators": (
3957 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3958 TosaErrorValidator.evWrongInputType,
3959 TosaErrorValidator.evWrongOutputType,
3960 TosaErrorValidator.evWrongInputList,
3961 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003962 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3963 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003965 },
3966 "reverse": {
3967 "op": Op.REVERSE,
3968 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003969 "build_fcn": (
3970 build_reverse,
3971 TosaTensorGen.tgBasic,
3972 TosaTensorValuesGen.tvgDefault,
3973 TosaArgGen.agAxis,
3974 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003975 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003976 "error_if_validators": (
3977 TosaErrorValidator.evAxisSmallerZero,
3978 TosaErrorValidator.evAxisLargerRank,
3979 TosaErrorValidator.evWrongInputType,
3980 TosaErrorValidator.evWrongOutputType,
3981 TosaErrorValidator.evWrongInputList,
3982 TosaErrorValidator.evWrongOutputList,
3983 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003984 },
3985 "slice": {
3986 "op": Op.SLICE,
3987 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003988 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003989 "build_fcn": (
3990 build_slice,
3991 TosaTensorGen.tgBasic,
3992 TosaTensorValuesGen.tvgDefault,
3993 TosaArgGen.agSlice,
3994 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003995 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003996 "error_if_validators": (
3997 TosaErrorValidator.evStartSmallerZero,
3998 TosaErrorValidator.evSizeSmallerEqualZero,
3999 TosaErrorValidator.evStartSizeOutsideBounds,
4000 TosaErrorValidator.evSizeOutputShapeMismatch,
4001 TosaErrorValidator.evInputSizeStartLengthMismatch,
4002 TosaErrorValidator.evWrongRank,
4003 TosaErrorValidator.evWrongInputType,
4004 TosaErrorValidator.evWrongOutputType,
4005 TosaErrorValidator.evWrongInputList,
4006 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004007 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004008 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004009 },
4010 "tile": {
4011 "op": Op.TILE,
4012 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004013 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 "build_fcn": (
4015 build_tile,
4016 TosaTensorGen.tgBasic,
4017 TosaTensorValuesGen.tvgDefault,
4018 TosaArgGen.agTile,
4019 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004020 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 "error_if_validators": (
4022 TosaErrorValidator.evWrongInputType,
4023 TosaErrorValidator.evWrongOutputType,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004026 TosaErrorValidator.evRankMismatch,
4027 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004028 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004029 },
4030 "transpose": {
4031 "op": Op.TRANSPOSE,
4032 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004033 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004034 "build_fcn": (
4035 build_transpose,
4036 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004037 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004038 TosaArgGen.agTranspose,
4039 ),
4040 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004041 "error_if_validators": (
4042 TosaErrorValidator.evIndexOutsideBounds,
4043 TosaErrorValidator.evIndexUsedTwice,
4044 TosaErrorValidator.evWrongInputType,
4045 TosaErrorValidator.evWrongOutputType,
4046 TosaErrorValidator.evWrongInputList,
4047 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004048 TosaErrorValidator.evWrongRank,
4049 TosaErrorValidator.evRankMismatch,
4050 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004051 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004052 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004053 # Data nodes
4054 "const": {
4055 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004056 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004057 "build_fcn": (
4058 build_const,
4059 TosaTensorGen.tgBasic,
4060 TosaTensorValuesGen.tvgDefault,
4061 None,
4062 ),
Luke Hutton65872422023-02-20 10:33:04 +00004063 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004064 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 "identity": {
4066 "op": Op.IDENTITY,
4067 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004068 "build_fcn": (
4069 build_unary,
4070 TosaTensorGen.tgBasic,
4071 TosaTensorValuesGen.tvgDefault,
4072 None,
4073 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004074 "types": TYPE_FIB,
4075 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004076 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004077 "gather": {
4078 "op": Op.GATHER,
4079 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4080 "operands": (1, 0),
4081 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 "build_fcn": (
4083 build_gather,
4084 TosaTensorGen.tgBasic,
4085 TosaTensorValuesGen.tvgDefault,
4086 None,
4087 ),
James Ward24dbc422022-10-19 12:20:31 +01004088 "types": (
4089 DType.INT8,
4090 DType.INT16,
4091 DType.INT32,
4092 DType.FP16,
4093 DType.BF16,
4094 DType.FP32,
4095 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004096 "error_if_validators": (
4097 TosaErrorValidator.evWrongInputType,
4098 TosaErrorValidator.evWrongOutputType,
4099 TosaErrorValidator.evWrongInputList,
4100 TosaErrorValidator.evWrongOutputList,
4101 TosaErrorValidator.evWrongRank,
4102 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 },
4104 "scatter": {
4105 "op": Op.SCATTER,
4106 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004107 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004108 "operands": (2, 0),
4109 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004110 "build_fcn": (
4111 build_scatter,
4112 TosaTensorGen.tgScatter,
4113 TosaTensorValuesGen.tvgDefault,
4114 None,
4115 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004116 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004117 "error_if_validators": (
4118 TosaErrorValidator.evWrongInputType,
4119 TosaErrorValidator.evWrongOutputType,
4120 TosaErrorValidator.evWrongInputList,
4121 TosaErrorValidator.evWrongOutputList,
4122 TosaErrorValidator.evWrongRank,
4123 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004124 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004125 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004126 "resize": {
4127 "op": Op.RESIZE,
4128 "operands": (1, 0),
4129 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004130 "build_fcn": (
4131 build_resize,
4132 TosaTensorGen.tgNHWC,
4133 TosaTensorValuesGen.tvgDefault,
4134 TosaArgGen.agResize,
4135 ),
James Ward24dbc422022-10-19 12:20:31 +01004136 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004137 "invalid_test_validators": (
4138 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004139 ),
4140 "error_if_validators": (
4141 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004142 TosaErrorValidator.evScaleSmallerEqualZero,
4143 TosaErrorValidator.evScaleNLargerMax,
4144 TosaErrorValidator.evScaleDLargerMax,
4145 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004146 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004147 TosaErrorValidator.evBorderSmallerMin,
4148 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004149 TosaErrorValidator.evWrongInputType,
4150 TosaErrorValidator.evWrongOutputType,
4151 TosaErrorValidator.evWrongRank,
4152 TosaErrorValidator.evWrongInputList,
4153 TosaErrorValidator.evWrongOutputList,
4154 TosaErrorValidator.evBatchMismatch,
4155 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004156 TosaErrorValidator.evResizeOutputShapeMismatch,
4157 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004158 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004159 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004160 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004161 "cast": {
4162 "op": Op.CAST,
4163 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004164 "build_fcn": (
4165 build_cast,
4166 TosaTensorGen.tgBasic,
4167 TosaTensorValuesGen.tvgDefault,
4168 TosaArgGen.agCast,
4169 ),
James Ward8b390432022-08-12 20:48:56 +01004170 "types": (
4171 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004172 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004173 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004174 DType.INT8,
4175 DType.INT16,
4176 DType.INT32,
4177 DType.BOOL,
4178 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004179 "error_if_validators": (
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongInputList,
4183 TosaErrorValidator.evWrongOutputList,
4184 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004185 },
4186 "rescale": {
4187 "op": Op.RESCALE,
4188 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004189 "build_fcn": (
4190 build_rescale,
4191 TosaTensorGen.tgBasic,
4192 TosaTensorValuesGen.tvgDefault,
4193 TosaArgGen.agRescale,
4194 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004195 "types": [
4196 DType.UINT8,
4197 DType.INT8,
4198 DType.INT16,
4199 DType.INT32,
4200 DType.INT48,
4201 DType.UINT16,
4202 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004203 "error_if_validators": (
4204 TosaErrorValidator.evInputZeroPointNotZero,
4205 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004206 TosaErrorValidator.evU16InputZeroPointNotValid,
4207 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004208 TosaErrorValidator.evScaleTrue,
4209 TosaErrorValidator.evScaleNotTrue,
4210 TosaErrorValidator.evWrongInputType,
4211 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
4214 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004215 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004216 # Custom
4217 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004218 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004219 # Two variants of cond_if, one that generates one of two constant tensors (no
4220 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4221 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004222 "cond_if_const": {
4223 "op": Op.COND_IF,
4224 "operands": (0, 2),
4225 "build_fcn": (
4226 build_cond_if_const,
4227 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004229 TosaArgGen.agCondIf,
4230 ),
4231 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004232 "error_if_validators": (
4233 TosaErrorValidator.evOutputListThenGraphMismatch,
4234 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004235 TosaErrorValidator.evCondIfCondNotMatchingBool,
4236 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004237 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004238 },
4239 "cond_if_binary": {
4240 "op": Op.COND_IF,
4241 "operands": (2, 0),
4242 "build_fcn": (
4243 build_cond_if_binary,
4244 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004245 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004246 TosaArgGen.agCondIf,
4247 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004248 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 "error_if_validators": (
4250 TosaErrorValidator.evInputListThenGraphMismatch,
4251 TosaErrorValidator.evInputListElseGraphMismatch,
4252 TosaErrorValidator.evOutputListThenGraphMismatch,
4253 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004254 TosaErrorValidator.evCondIfCondNotMatchingBool,
4255 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004256 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004257 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004258 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004259 "while_loop": {
4260 "op": Op.WHILE_LOOP,
4261 "operands": (0, 1),
4262 "build_fcn": (
4263 build_while_loop,
4264 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004266 TosaArgGen.agWhileLoop,
4267 ),
4268 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004269 "error_if_validators": (
4270 TosaErrorValidator.evInputListOutputListMismatch,
4271 TosaErrorValidator.evInputListCondGraphMismatch,
4272 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4273 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4274 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004275 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004276 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004277 },
Luke Hutton57287132023-02-06 14:54:18 +00004278 "fft2d": {
4279 "op": Op.FFT2D,
4280 "operands": (2, 0),
4281 "rank": (3, 3),
4282 "build_fcn": (
4283 build_fft2d,
4284 TosaTensorGen.tgFFT2d,
4285 TosaTensorValuesGen.tvgDefault,
4286 TosaArgGen.agFFT2d,
4287 ),
4288 "types": [DType.FP32],
4289 "error_if_validators": (
4290 TosaErrorValidator.evWrongInputType,
4291 TosaErrorValidator.evWrongOutputType,
4292 TosaErrorValidator.evWrongInputList,
4293 TosaErrorValidator.evWrongOutputList,
4294 TosaErrorValidator.evWrongRank,
4295 TosaErrorValidator.evBatchMismatch,
4296 TosaErrorValidator.evKernelNotPowerOfTwo,
4297 TosaErrorValidator.evFFTInputShapeMismatch,
4298 TosaErrorValidator.evFFTOutputShapeMismatch,
4299 ),
4300 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004301 "rfft2d": {
4302 "op": Op.RFFT2D,
4303 "operands": (1, 0),
4304 "rank": (3, 3),
4305 "build_fcn": (
4306 build_rfft2d,
4307 TosaTensorGen.tgRFFT2d,
4308 TosaTensorValuesGen.tvgDefault,
4309 TosaArgGen.agNone,
4310 ),
4311 "types": [DType.FP32],
4312 "error_if_validators": (
4313 TosaErrorValidator.evWrongInputType,
4314 TosaErrorValidator.evWrongOutputType,
4315 TosaErrorValidator.evWrongInputList,
4316 TosaErrorValidator.evWrongOutputList,
4317 TosaErrorValidator.evWrongRank,
4318 TosaErrorValidator.evBatchMismatch,
4319 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004320 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004321 ),
4322 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004323 }
4324
Kevin Cheng550ccc52021-03-03 11:21:43 -08004325
Eric Kunzee5e26762020-10-13 16:11:07 -07004326class OutputShaper:
4327 # Methods in this class compute the expected output shape and datatype
4328 # for common classes of operations
def __init__(self):
    """Nothing to set up: the constructor deliberately does no work."""
4331
4332 # These methods return arguments that can be used for
4333 # creating a new output tensor
4334 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a binary op that broadcasts ``a`` against ``b``.

    ``a`` and ``b`` must share a dtype and, unless ``error_name`` is
    ``ErrorIf.RankMismatch``, a rank.  With no error requested every size-1
    dimension of ``a`` takes the matching size from ``b``; for ERROR_IF
    variants the shape of ``a`` is kept and then perturbed as requested.

    Args:
        ser: serializer; only ``addOutput(shape, dtype)`` is used.
        rng: random generator providing ``integers()`` and ``choice()``.
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: ``ErrorIf`` value to provoke, or None for a valid output.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = [
        b.shape[i] if dim == 1 and error_name is None else dim
        for i, dim in enumerate(a.shape)
    ]

    rank = len(a.shape)
    if rank > 0:
        # The index is drawn even when it is not used so that the random
        # stream (and so every later choice) stays the same for existing seeds.
        # Rank 0 is skipped: there is no dimension to fuzz, and an empty
        # [0, 0) range is rejected by numpy's Generator.integers
        # (assumes rng is a numpy Generator - TODO confirm).
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any dtype other than the input's is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - {a.dtype})
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004367
4368 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose inputs must match exactly.

    ``a`` and ``b`` are asserted to have the same rank, dtype and size in
    every dimension; the output gets a fresh copy of that shape and the
    dtype of ``a``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        # No broadcasting here: each dimension has to agree
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004379
4380 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for a unary op: same shape and dtype as ``a``.

    When ``error_name`` is ``ErrorIf.WrongOutputType`` the dtype is instead
    picked at random from a fixed pool with the input dtype removed.

    Returns:
        Whatever ``ser.addOutput(a.shape, dtype)`` returns.
    """
    outputDType = a.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Set difference drops the one correct dtype from the pool
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004398
4399 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select between ``a`` and ``b`` driven by ``cond``.

    ``a`` and ``b`` must share a dtype and, unless ``error_name`` is
    ``ErrorIf.RankMismatch``, all three inputs must share a rank.  The shape
    follows ``cond``; with no error requested a size-1 dimension of ``cond``
    is widened to the largest matching size among ``cond``, ``a`` and ``b``.

    Args:
        ser: serializer; only ``addOutput(shape, dtype)`` is used.
        rng: random generator providing ``integers()`` and ``choice()``.
        cond, a, b: input tensors exposing ``shape`` (and ``dtype`` for a/b).
        error_name: ``ErrorIf`` value to provoke, or None for a valid output.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            shape.append(max(cond_dim, a.shape[i], b.shape[i]))
        else:
            shape.append(cond_dim)

    rank = len(a.shape)
    if rank > 0:
        # The index is drawn even when it is not used so that the random
        # stream stays the same for existing seeds.  Rank 0 is skipped: there
        # is no dimension to fuzz, and an empty [0, 0) range is rejected by
        # numpy's Generator.integers (assumes rng is a numpy Generator -
        # TODO confirm).
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any dtype other than the input's is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - {a.dtype})
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004432
4433 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for a broadcasting comparison of ``a`` and ``b``.

    ``a`` and ``b`` must share a dtype and, unless ``error_name`` is
    ``ErrorIf.RankMismatch``, a rank.  Each size-1 dimension of ``a`` takes
    the size of ``b`` where ``b`` has that dimension.

    Args:
        ser: serializer; only ``addOutput(shape, dtype)`` is used.
        rng: random generator providing ``integers()`` and ``choice()``.
        a, b: input tensors exposing ``shape`` and ``dtype``.
        error_name: ``ErrorIf`` value to provoke, or None for a valid output.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast (the length check keeps rank-mismatched inputs in range)
    shape = [
        b.shape[i] if dim == 1 and len(b.shape) > i else dim
        for i, dim in enumerate(a.shape)
    ]

    rank = len(a.shape)
    if rank > 0:
        # The index is drawn even when it is not used so that the random
        # stream stays the same for existing seeds.  Rank 0 is skipped: there
        # is no dimension to fuzz, and an empty [0, 0) range is rejected by
        # numpy's Generator.integers (assumes rng is a numpy Generator -
        # TODO confirm).
        fuzz_idx = rng.integers(0, rank)
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # The correct output is BOOL, so every entry here is a wrong dtype
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004466
4467 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction of ``a`` along ``axis``.

    The output shape is a copy of ``a.shape`` with ``shape[axis]`` set to 1.
    The two axis ERROR_IF cases leave the shape untouched, and
    ``ShapeOfAxisNotOne`` replaces a reduced dimension that is already 1 by
    a random larger value.  ``WrongOutputType`` picks a dtype different from
    the input's.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # Keep the input size on the axis unless it happens to be 1 already
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004495
4496 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for an argmax of ``a`` along ``axis``.

    The output shape is ``a.shape`` with the ``axis`` entry removed (kept for
    the two axis ERROR_IF cases).  ``ArgmaxOutputRankMismatch`` then drops or
    appends a dimension, ``ArgmaxOutputShapeMismatch`` enlarges every
    dimension, and ``WrongOutputType`` picks a dtype other than INT32.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        # Only shrink when a dimension would still be left afterwards
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    outputDType = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(out_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004529
4530 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    IFM: NHWC
    Filter: OHWI
    OFM: NHWC

    ``padding`` is read as four values (two per spatial dimension, height
    first); ``strides`` and ``dilations`` as two.  The output dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """

    def out_size(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Size of one spatial output dimension
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow height, 2 -> grow width, 3 -> grow both
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the wrong-dtype pool
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004581
4582 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    IFM: NDHWC
    Filter: ODHWI
    OFM: NDHWC

    ``padding`` is read as six values (two per spatial dimension, in D, H, W
    order); ``strides`` and ``dilations`` as three.  The output dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """

    def out_size(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Size of one spatial output dimension
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    d = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1 -> grow depth, 2 -> grow height, 3 -> grow width, 4 -> grow all
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 4):
            d += rng.choice(steps) * strides[0]
        if which in (2, 4):
            h += rng.choice(steps) * strides[1]
        if which in (3, 4):
            w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the wrong-dtype pool
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4643
4644 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a depthwise 2D convolution.

    IFM: NHWC
    Filter: HWCM
    OFM: NHW C*M

    ``padding`` is read as four values (two per spatial dimension, height
    first); ``strides`` and ``dilations`` as two.  The output dtype is
    ``accum_dtype`` unless an ERROR_IF case overrides it.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """

    def out_size(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Size of one spatial output dimension
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow height, 2 -> grow width, 3 -> grow both
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            w += rng.choice(steps) * strides[1]

    # Output channels are the filter's C and M dimensions multiplied together
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the wrong-dtype pool
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004694
4695 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling op.

    input: NHWC

    ``pad`` is read as four values (two per spatial dimension, height
    first); ``kernel`` and ``stride`` as two.  The output keeps the batch and
    channel sizes and the dtype of ``ifm`` unless an ERROR_IF case overrides
    them.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, dtype)`` returns.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride (or negative pad) is used set dimensions to 1,
        # test is invalid anyway.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow height, 2 -> grow width, 3 -> grow both
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    outputDType = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004732
4733 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a fully connected op.

    input: N, IC
    filter: OC, IC
    output: N, OC

    The output dtype is always ``accum_dtype``: it is validated in arg_gen
    (also invalidated for ErrorIf), so ``rng`` and ``error_name`` are unused.

    Returns:
        Whatever ``ser.addOutput([N, OC], accum_dtype)`` returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004745
4746 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a batched matrix multiply.

    a: N, H, C
    b: N, C, W
    out: N, H, W

    Args:
        ser: serializer; only ``addOutput(shape, dtype)`` is used.
        rng: random generator providing ``choice()``.
        a, b: input tensors exposing ``shape`` (and ``dtype`` for ``a``).
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: ``ErrorIf`` value to provoke, or None for a valid output.

    Returns:
        Whatever ``ser.addOutput(output_shape, out_dtype)`` returns.

    Raises:
        ValueError: if ``WrongOutputType`` is requested for an input dtype
            with no list of incorrect output types (this used to surface as
            an UnboundLocalError).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each list leaves out one dtype per input group (INT32 for INT8,
        # INT48 for INT16) or all of the floating point dtypes
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Fail with a clear message instead of an unbound local below
            raise ValueError(
                f"matmulOp: no incorrect output types defined for dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004793
4794 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for concatenating the tensors ``a`` along ``axis``.

    The output starts as a copy of the first input's shape; the sizes of the
    remaining inputs along ``axis`` are added to it when the shape can be
    calculated.  ``ConcatShapeSumMismatch`` then enlarges that dimension and
    ``WrongOutputType`` picks a dtype different from the first input's.

    Returns:
        Whatever ``ser.addOutput(output_shape, dtype)`` returns.
    """
    first = a[0]
    others = a[1:]

    # calculate the output shape, if possible, otherwise just use the first input shape
    output_shape = first.shape.copy()
    # unable to concat tensors of different ranks or along an invalid axis
    uncomputable = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in uncomputable:
        for tensor in others:
            output_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    outputDType = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        outputDType = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004829
4830 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for padding ``a``.

    ``padding[i]`` holds the two pad amounts for dimension ``i``; both are
    added to that dimension's size.  ``PadOutputShapeMismatch`` shrinks one
    random dimension, ``RankMismatch`` swaps in a shape from
    ``gtu.get_rank_mismatch_shape`` and ``WrongOutputType`` picks a dtype
    different from the input's.

    Returns:
        Whatever ``ser.addOutput(output_shape, dtype)`` returns.
    """
    output_shape = a.shape.copy()
    for idx, dim in enumerate(a.shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004860
4861 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a dim op: an empty shape of dtype SHAPE.

    ``a`` and ``axis`` are not used here.  ``WrongOutputType`` swaps the
    dtype for a random entry of a fixed pool that does not contain SHAPE.

    Returns:
        Whatever ``ser.addOutput([], dtype)`` returns.
    """
    outputDType = DType.SHAPE

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Routed through a set, as before, so the choice order is unchanged
        outputDType = rng.choice(list(set(candidates)))

    return ser.addOutput([], outputDType)
4881
4882 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for reshaping ``a`` to ``shape``.

    The output uses a copy of ``shape`` and the dtype of ``a``.
    ``TensorSizeInputOutputMismatch`` enlarges every dimension by a random
    amount and ``WrongOutputType`` picks a dtype different from the input's.

    Returns:
        Whatever ``ser.addOutput(output_shape, dtype)`` returns.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004906
4907 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for slicing ``input``.

    The output uses a copy of ``size`` as its shape and the dtype of
    ``input``; ``start`` is not used here.  ERROR_IF cases either pick a
    different dtype (``WrongOutputType``) or replace/perturb the shape.

    Returns:
        Whatever ``ser.addOutput(output_shape, dtype)`` returns.
    """
    # The dtype is settled first so the random draws keep their order
    outputDType = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(list(set(candidates) - {input.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dimensions of 2 or less may only grow
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004940
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the TILE output tensor with the serializer and return it.

    Each output dimension is the matching dimension of ``a`` multiplied by
    the corresponding entry of ``multiples``.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for idx, (dim, repeat) in enumerate(zip(a.shape, multiples)):
        tiled_shape[idx] = dim * repeat

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = gtu.get_rank_mismatch_shape(rng, tiled_shape)

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Choose any candidate dtype other than the input's own
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        outputDType = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(tiled_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004969
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the TRANSPOSE output tensor with the serializer and return it.

    The output shape is the shape of ``a`` reordered by ``perms``; the dtype
    is that of ``a`` unless WrongOutputType is requested.
    """
    permuted = a.shape.copy()

    assert len(perms) == len(permuted)

    # For these two error cases perms is not used to index the input shape,
    # so the output simply keeps the (copied) input shape
    skip_permute = error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if not skip_permute:
        for idx, src_axis in enumerate(perms):
            permuted[idx] = a.shape[src_axis]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Push every dimension up by a random amount from rng.integers(1, 10)
        for idx, dim in enumerate(permuted):
            permuted[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Choose any candidate dtype other than the input's own
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        outputDType = rng.choice(list(candidates - {a.dtype}))

    return ser.addOutput(permuted, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005002
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the GATHER output tensor with the serializer and return it.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` and the dtype
    is that of ``values`` unless WrongOutputType is requested.
    """
    # The rank-3 check on values is relaxed only for the WrongRank test
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    gathered_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    outputDType = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Choose any candidate dtype other than that of values
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        outputDType = rng.choice(list(candidates - {values.dtype}))

    return ser.addOutput(gathered_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005028
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the SCATTER output tensor with the serializer and return it.

    The output reuses ``values_in.shape`` (the same object, not a copy) and
    the dtype of ``values_in`` unless WrongOutputType is requested.
    """
    # The rank-3 check on values_in is relaxed only for the WrongRank test
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    scattered_shape = values_in.shape

    outputDType = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Choose any candidate dtype other than that of values_in
        candidates = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        outputDType = rng.choice(list(candidates - {values_in.dtype}))

    return ser.addOutput(scattered_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005057
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the TABLE output tensor with the serializer and return it.

    The output has the input's shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise, unless WrongOutputType is requested.
    """
    # Input dtype is only checked when it is not the thing under test
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        # Keep list order while dropping the one correct dtype
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = [dtype for dtype in candidates if dtype != output_dtype]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005077
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the RESIZE output tensor with the serializer and return it.

    Output height and width are derived from ``input.shape[1:3]`` together
    with ``scale`` (y numerator, y denominator, x numerator, x denominator),
    ``offset`` and ``border``. ``mode`` and ``input_dtype`` are accepted but
    not read in this function.
    """
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
        scale[0],
        scale[1],
        scale[2],
        scale[3],
    )
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Force every scale term to at least 1 for the size calculation
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(term, 1) for term in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    in_h = input.shape[1]
    in_w = input.shape[2]
    out_h = ((in_h - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    out_w = ((in_w - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h = min(out_h, limit)
            out_w = min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Step in multiples of scale_y/x_d so we don't hit non-integer error case
        if change in [1, 3]:
            if out_h + scale_y_d < gtu.MAX_RESIZE_DIMENSION:
                out_h += scale_y_d
            else:
                out_h -= scale_y_d
                assert out_h > 0  # Should have been caught in agResize
        if change in [2, 3]:
            if out_w + scale_x_d < gtu.MAX_RESIZE_DIMENSION:
                out_w += scale_x_d
            else:
                out_w -= scale_x_d
                assert out_w > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[3] is not read here; the last entry repeats shape[0]
        channels = input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
        channels = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        channels = input.shape[3] + rng.integers(1, 10)
    else:
        channels = input.shape[3]

    output_dims = [batch, out_h, out_w, channels]
    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005156
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output with the shape of ``val`` and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not read in this function.
    """
    converted = ser.addOutput(val.shape, out_dtype)
    return converted
Eric Kunzee5e26762020-10-13 16:11:07 -07005160
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the TRANSPOSE_CONV2D output tensor and return it.

    The output uses ``output_shape`` and ``accum_dtype`` unless an ERROR_IF
    case below overrides them. NOTE: for ConvOutputShapeMismatch the
    caller's ``output_shape`` is modified in place.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change == 1 alters index 1, 2 alters index 2, 3 alters both
        for axis, selectors in ((1, [1, 3]), (2, [2, 3])):
            if change in selectors:
                output_shape[axis] += rng.choice(choices)

    out_dtype = accum_dtype
    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005186
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two FFT2D output tensors and return them as a list.

    Both outputs share one shape (a copy of ``ifm1.shape``) and one dtype
    (that of ``ifm1``) unless an ERROR_IF case below changes them.
    """
    assert ifm1.dtype == ifm2.dtype

    # Shape and rank checks are relaxed only for their own error tests
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] = output_shape[0] + rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow one of the two non-batch dimensions, chosen at random
        axis = rng.choice([1, 2])
        output_shape[axis] = output_shape[axis] + rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5217
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two RFFT2D output tensors and return them as a list.

    Both outputs share one shape: the input shape with its last dimension
    replaced by ``last // 2 + 1``. The dtype is that of ``value`` unless
    WrongOutputType is requested.
    """
    in_shape = value.shape
    # The rank-3 check is relaxed only for the WrongRank test
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] = output_shape[0] + rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Grow one of the two non-batch dimensions, chosen at random
        axis = rng.choice([1, 2])
        output_shape[axis] = output_shape[axis] + rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]