blob: 8beb2ae8b0d2715af3d127cdcb9f40803ffb553b [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010012from generator.tosa_arg_gen import TosaArgGen
13from generator.tosa_arg_gen import TosaQuantGen
14from generator.tosa_arg_gen import TosaTensorGen
15from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000016from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010017from generator.tosa_error_if import TosaErrorIfArgGen
18from generator.tosa_error_if import TosaErrorValidator
19from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010020from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
# Banner written at the top of each generated C++ meta data file.
# The copyright year is taken from the clock when this module is imported.
TOSA_AUTOGENERATED_HEADER = "".join(
    (
        f"// Copyright (c) {datetime.today().year}, ARM Limited\n",
        "// SPDX-License-Identifier: Apache-2.0\n",
        "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n",
    )
)
28
Matthew Haddonb724efc2021-08-25 16:40:29 +010029
Eric Kunzee5e26762020-10-13 16:11:07 -070030class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010031 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000032 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010034 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010035 TOSA_8K_LEVEL_MAX_KERNEL = 8192
36 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010037
Jeremy Johnson1271c442023-09-05 11:39:26 +010038 # Main compliance dot product statistical test range
39 TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
40 TOSA_MI_DOT_PRODUCT_MIN = 1000
41
def __init__(self, args):
    """Set up generator state from the parsed options.

    Args:
        args: options object; this constructor reads ``output_dir``,
            ``random_seed`` and ``tensor_fp_value_range`` from it and keeps
            the whole object as ``self.args``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer exists until createSerializer() is called
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this shape instead of a random one
    self.targetted_shape = None
    # Floating point value range - the two bounds may come in either order
    fp_value_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_value_range)
    self.random_fp_high = max(fp_value_range)
    # Validator used on the desc.json contents before they are written
    self.descSchemaValidator = TestDescSchemaValidator()
Eric Kunzee5e26762020-10-13 16:11:07 -070058
def createSerializer(self, opName, testPath):
    """Create the output directory for a test and a new serializer for it.

    Args:
        opName: operator name, used as the first directory level.
        testPath: test directory name, placed below ``opName``.
    """
    self.testPath = os.path.join(opName, testPath)

    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    # Lazy data generation takes priority over dumping the constants
    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        constMode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070072
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    return self.ser
75
def serialize(self, testName, metaData=None):
    """Write the flatbuffer, optional C++ meta data files and desc.json.

    Args:
        testName: base name used for the ``.tosa`` and ``.cpp`` files.
        metaData: optional dictionary stored under ``"meta"`` in desc.json.
            Its ``"data_gen"`` entry (only with lazy data generation) and
            its ``"compliance"`` entry are also written out as C++ raw
            string constants.
    """
    path = Path(self.basePath) / self.testPath

    def write_cpp_meta(path_md, title, declaration, payload):
        # Emit one JSON payload wrapped in a C++ raw string literal
        with path_md.open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(title)
            fd.write(declaration)
            json.dump(payload, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    with (path / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        # Add extra meta data to desc.json
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if self.args.lazy_data_gen and "data_gen" in metaData:
            # Output datagen meta data as CPP data
            write_cpp_meta(
                path / f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f'const char* json_tdg_config_{path.stem} = R"(',
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Output compliance meta data as CPP data
            write_cpp_meta(
                path / f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f'const char* json_tvf_config_{path.stem} = R"(',
                metaData["compliance"],
            )

    # Write desc.json
    with (path / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700118
def resetRNG(self, seed=None):
    """Replace the random number generator with a freshly seeded one.

    Args:
        seed: seed for the new generator; when None, ``random_seed + 1``
            is used.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
123
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value range boundaries for a data type.

    Args:
        dtype: DType value to look up.
        high_inclusive: when True, the high boundary of a boolean or
            integer range is itself the largest value; otherwise it is one
            past the largest value.

    Returns:
        Tuple of (low, high). Floating point types return the configured
        random_fp_low/random_fp_high pair whatever high_inclusive is.

    Raises:
        Exception: if the dtype is not recognised.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16):
        return (self.random_fp_low, self.random_fp_high)

    # (dtype, low, exclusive high) - scanned in order
    bounds_table = (
        (DType.BOOL, 0, 2),
        (DType.UINT8, 0, 256),
        (DType.UINT16, 0, 65536),
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, -7, 8),
        (DType.INT8, -128, 128),
        (DType.INT16, -32768, 32768),
        (DType.INT32, -(1 << 31), (1 << 31)),
        # restricting too large value for SHAPE
        (DType.SHAPE, -(1 << 31), (1 << 31)),
        (DType.INT48, -(1 << 47), (1 << 47)),
    )
    for known_dtype, low, high in bounds_table:
        if dtype == known_dtype:
            break
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (low, high - 1)
    # Exclusive high: low <= range < high
    return (low, high)
158
def getRandTensor(self, shape, dtype):
    """Return a numpy array of random values for the given shape and dtype.

    The numpy element type depends on dtype: bool for BOOL, int64 for
    INT48, float16 for FP16, float32 for FP32 and BF16, and int32 for all
    remaining integer types.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT48:
        return np.int64(self.rng.integers(low=low, high=high, size=shape))
    if dtype not in (DType.FP16, DType.BF16, DType.FP32):
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    values = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(values)

    values_f32 = np.float32(values)
    if dtype == DType.BF16:
        # Pass through the bf16 conversion helper, result kept as float32
        return np.float32(gtu.vect_f32_to_bf16(values_f32))
    return values_f32
Eric Kunzee5e26762020-10-13 16:11:07 -0700181
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder tensor to the serializer per shape/dtype pair.

    Random data is generated for each placeholder unless lazy data
    generation is enabled, in which case None is passed as the data.

    Returns:
        List of the values returned by the serializer's addPlaceholder.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        if self.args.lazy_data_gen:
            arr = None
        else:
            arr = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, arr))
    return placeholders
194
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant tensor to the serializer per shape/dtype pair.

    Random data is generated for each constant unless lazy data generation
    is enabled, in which case None is passed as the data.

    Returns:
        List of the values returned by the serializer's addConst.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        if self.args.lazy_data_gen:
            arr = None
        else:
            arr = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, arr))
    return consts
207
def makeShape(self, rank):
    """Return a shape as an int32 numpy array.

    A target shape set with setTargetShape() is returned as is (rank is
    ignored); otherwise ``rank`` random dimensions are drawn from
    args.tensor_shape_range.
    """
    dims = self.targetted_shape
    if not dims:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(
            low=shape_range[0],
            high=shape_range[1],
            size=rank,
        )
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -0700218
def setTargetShape(self, shape):
    """Set the shape that makeShape() should return (None clears it)."""
    self.targetted_shape = shape
221
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from self.rng between low and high."""
    (value,) = np.int32(self.rng.integers(low=low, high=high, size=1))
    return value
224
def getRandNumberDType(self, dtype):
    """Return a single random value for the given data type.

    The value range comes from getDTypeRange(); BOOL picks between False
    and True, INT48 gives an np.int64 and other integers an np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.FP32, DType.BF16):
        sample_f32 = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(sample_f32)
        return sample_f32

    # INT48 is the special size that needs 64 bits of storage
    int_type = np.int64 if dtype == DType.INT48 else np.int32
    return int_type(self.rng.integers(low, high, size=1))[0]
242
def shapeStr(self, shape):
    """Return the dimensions of shape joined with "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700251
def typeStr(self, dtype):
    """Return the string name for a dtype, or for a list/tuple of dtypes.

    For a list or tuple (at least 2 entries) every entry is converted, but
    only the first two names are joined with "x".

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700265
def typeWidth(self, dtype):
    """Return the "width" entry of gtu.DTYPE_ATTRIBUTES for a data type.

    Raises:
        Exception: if the dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700272
def constrictBatchSize(self, shape):
    """Clamp shape[0] to args.max_batch_size and return the same shape.

    The shape is modified in place. Nothing is changed when no maximum is
    configured or when explicit target shapes have been given.
    """
    batch_limit = self.args.max_batch_size
    if batch_limit and not self.args.target_shapes:
        shape[0] = min(shape[0], batch_limit)
    return shape
278
def makeDimension(self):
    """Return one random dimension drawn from args.tensor_shape_range."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
283
def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
    """Build the compliance meta data dictionary for an output tensor.

    Args:
        op: operator dictionary; its "op" and optional "compliance"
            entries are read.
        argsDict: test arguments; "dg_type" is read, plus "s" and "ks" for
            dot product data generation.
        outputTensor: expected output tensor; only its dtype is read.
        errorName: ERROR_IF case name, or a false value for a normal test.

    Returns:
        Dictionary with "mode", "data_type" and any mode specific info, or
        None for error tests and non floating point outputs.
    """
    if errorName or not gtu.dtypeIsFloat(outputTensor.dtype):
        # No compliance for error tests or integer tests currently
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        # Placeholder keeps "mode" as the first key; filled in below
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    dg_type = argsDict["dg_type"]
    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": argsDict["ks"],
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "compliance" in op and "ulp" in op["compliance"]:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
    elif op["op"] in (Op.ADD, Op.MUL, Op.SUB, Op.CEIL, Op.FLOOR, Op.CAST):
        mode = gtu.ComplianceMode.ROUND
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
315
316 # Build Op functions
317 # Create the output tensor (calling OutputShaper as needed)
318 # Do final tweaks to attributes (if necessary for errorIf)
319 # Add Op into graph
320 # Return resulting tensor information or BuildInfo
321
class BuildInfo:
    """Build result holder: the result tensor and its compliance dictionary."""

    def __init__(self, resultTensor, complianceDict):
        self.resultTensor, self.complianceDict = resultTensor, complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700328
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a unary operator to the graph and return its result tensor.

    Args:
        op: operator dictionary, or a bare int operator value.
        a: input tensor.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of [input zero point, output zero point]. It is passed
            to the validator and used for the NEGATE attribute, where it
            falls back to [0, 0] when not supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # Same fallback as build_pool2d: use zero for both zero points when
        # no quantization info was given, rather than indexing None.
        # Applied after validation so the validator still sees the
        # caller's original qinfo.
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
379
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting binary operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
412
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator and return its result tensor.

    validator_fcns and error_name are accepted but not used here.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    in_names, out_names = [a.name, b.name], [result_tens.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
417
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        round: value stored in the ArithmeticRightShift attribute.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
455
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        shift: value stored in the Mul attribute.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply:
    # Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tens.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        # Override with a randomly chosen output type that is not INT32
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tens.setDtype(self.rng.choice(wrong_dtypes))

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
500
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: input tensor.
        table: value stored in the Table attribute.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
534
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        cond: condition tensor (first operator input).
        a: second operator input tensor.
        b: third operator input tensor.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Tensor name lists may be invalidated for error if checks
    in_names = [cond.name, a.name, b.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
571
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a comparison operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: first input tensor.
        b: second input tensor.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name, b.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
610
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        a: input tensor.
        axis: value stored in the Axis attribute.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Tensor name lists may be invalidated for error if checks
    in_names = [a.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
645
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator to the graph.

    Args:
        op: operator dictionary ("op" and "operands" are read).
        input: input tensor.
        accum_dtype: accumulator data type stored in the Pool attribute.
        stride: stride stored in the Pool attribute.
        pad: padding stored in the Pool attribute.
        kernel: kernel stored in the Pool attribute.
        validator_fcns: ERROR_IF validator functions to check.
        error_name: ERROR_IF case being generated, or None.
        qinfo: pair of [input zero point, output zero point]; [0, 0] is
            used for the attribute when it is None.

    Returns:
        The result tensor, or None if the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # For wrong input type tests on non 8-bit inputs, work out fresh zero
    # points for the input and output types
    wrong_input_type = error_name == ErrorIf.WrongInputType
    if wrong_input_type and input.dtype not in [DType.INT8, DType.UINT8]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Tensor name lists may be invalidated for error if checks
    in_names = [input.name]
    out_names = [result_tens.name]
    num_placeholders, num_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_placeholders + num_consts,
    )
    if not is_valid:
        return None

    # The validator sees the caller's qinfo; the attribute needs real values
    input_zp, output_zp = (0, 0) if qinfo is None else (qinfo[0], qinfo[1])

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, input_zp, output_zp, accum_dtype)

    self.ser.addOperator(op["op"], in_names, out_names, attr)
    return result_tens
707
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a max-pool operator by delegating to build_pool2d.

    Max pooling has no accumulator type, so DType.UNKNOWN is supplied for
    build_pool2d's accum_dtype argument. The return value is whatever
    build_pool2d returns.
    """
    return self.build_pool2d(
        op,
        input,
        accum_dtype=DType.UNKNOWN,
        stride=stride,
        pad=pad,
        kernel=kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
732
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        filter: Weight tensor (its ``name``, ``dtype`` and ``shape`` are used).
        bias: Bias tensor (only its ``name`` is used here).
        accum_dtype: Accumulator type handed to OutputShaper.conv2dOp.
        strides: Stride values stored in the convolution attribute.
        padding: Padding values; must contain exactly 4 entries.
        dilations: Dilation values stored in the convolution attribute.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
804
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        filter: Weight tensor (its ``name``, ``dtype`` and ``shape`` are used).
        bias: Bias tensor (only its ``name`` is used here).
        accum_dtype: Accumulator type handed to OutputShaper.conv3dOp.
        strides: Stride values stored in the convolution attribute.
        padding: Padding values; must contain exactly 6 entries.
        dilations: Dilation values stored in the convolution attribute.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
876
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a transpose convolution operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        filter: Weight tensor (its ``name``, ``dtype`` and ``shape`` are used).
        bias: Bias tensor (only its ``name`` is used here).
        accum_dtype: Accumulator type handed to OutputShaper.transposeConv2DOp.
        stride: Stride values stored in the attribute.
        out_pad: Output padding values; must contain exactly 4 entries.
        output_shape: Requested output shape. This value is stored in the
            attribute, while the validators are given the shape of the
            tensor returned by OutputShaper.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
939
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise 2D convolution operator and return its output.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        filter: Weight tensor (its ``name``, ``dtype`` and ``shape`` are used).
        bias: Bias tensor (only its ``name`` is used here).
        accum_dtype: Accumulator type handed to OutputShaper.depthwiseConv2dOp.
        strides: Stride values stored in the convolution attribute.
        padding: Padding values stored in the convolution attribute.
        dilations: Dilation values stored in the convolution attribute.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1010
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully connected operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        ifm: Input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        filter: Weight tensor (its ``name`` and ``dtype`` are used).
        bias: Bias tensor (only its ``name`` is used here).
        accum_dtype: Accumulator type handed to OutputShaper.fullyConnectedOp
            and to the validators.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1059
def build_matmul(
    self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a matrix multiply operator and return its build info.

    Args:
        op: Operator table entry; ``op["op"]`` and ``op["operands"]`` are read.
        a: First input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        b: Second input tensor (its ``name``, ``dtype`` and ``shape`` are used).
        args_dict: Argument dictionary; ``args_dict["acc_type"]`` supplies the
            accumulator type, and the whole dictionary is forwarded to
            self.tensorComplianceMetaData.
        validator_fcns: Passed through to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points (indexed as [0] and [1]); treated as
            ``[0, 0]`` for the attribute when None.

    Returns:
        A TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when TosaErrorValidator.evValidateErrorIfs returns
        a falsy value.
    """
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero for both
    # zero points (as build_pool2d does) rather than raising a TypeError
    # on qinfo[0]. Applied after validation so the validators still see
    # the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001107
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Serialize a reduction operator over ``axis`` and return its output.

    The output tensor is produced by OutputShaper.reduceOp and the axis is
    stored in an AxisAttribute.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    reduced = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reduced.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=reduced.shape,
        input_dtype=a.dtype,
        output_dtype=reduced.dtype,
        result_tensors=[reduced],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return reduced
1142
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a clamp operator with randomly drawn bounds.

    Two values are drawn with self.getRandNumberDType(a.dtype) and ordered
    into (min, max). For ErrorIf.MaxSmallerMin the pair is redrawn until
    the values differ and the ordering is deliberately reversed.

    Returns:
        The output tensor, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    clamped = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Two consecutive draws of the input dtype, in list order.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_pair()
    smaller_max_test = error_name == ErrorIf.MaxSmallerMin
    if smaller_max_test:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_pair()

    low, high = min(bounds), max(bounds)
    if smaller_max_test:
        min_val, max_val = high, low
    else:
        min_val, max_val = low, high

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [clamped.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=clamped.shape,
        input_dtype=a.dtype,
        output_dtype=clamped.dtype,
        result_tensors=[clamped],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes: bounds go in the first value pair, the second
        # pair is zeroed.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes: first pair zeroed, bounds go in the second pair.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], names_in, names_out, attr)
    return clamped
1198
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a leaky ReLU operator and return its output tensor.

    The attribute value is a random number drawn with
    self.getRandNumberDType(DType.FP32). No ERROR_IF validation is run
    here; validator_fcns is accepted but not used.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    rand_fp32 = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(rand_fp32)

    names_in = [a.name]
    names_out = [activated.name]
    self.ser.addOperator(op["op"], names_in, names_out, relu_attr)
    return activated
1207
1208 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PReLU operator with a single input and no attribute.

    No ERROR_IF validation is run here; validator_fcns is accepted but
    not used.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    names_in = [a.name]
    names_out = [activated.name]
    self.ser.addOperator(op["op"], names_in, names_out)
    return activated
1214
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a sigmoid operator and return its output tensor.

    Returns:
        The output tensor from OutputShaper.unaryOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [activated.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=activated.shape,
        input_dtype=a.dtype,
        output_dtype=activated.dtype,
        result_tensors=[activated],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)
    return activated
1245
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a tanh operator and return its output tensor.

    Returns:
        The output tensor from OutputShaper.unaryOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [activated.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=activated.shape,
        input_dtype=a.dtype,
        output_dtype=activated.dtype,
        result_tensors=[activated],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)
    return activated
1276
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Serialize an erf operator and return its output tensor.

    Returns:
        The output tensor from OutputShaper.unaryOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    activated = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [activated.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=activated.shape,
        input_dtype=a.dtype,
        output_dtype=activated.dtype,
        result_tensors=[activated],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)
    return activated
1307
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a concat operator and return its output tensor.

    The positional arguments after ``op`` are the input tensors followed by
    the axis as the final element. The first tensor's shape and dtype are
    the ones reported to the validators as the input shape/dtype.

    Returns:
        The output tensor from OutputShaper.concatOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # To store variable length list of input tensors we need to store axis
    # along with it
    *_, axis = a
    tensors = a[:-1]

    joined = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    tensor_names = [tensor.name for tensor in tensors]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, tensor_names, [joined.name]
    )

    first = tensors[0]
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=joined.shape,
        input_dtype=first.dtype,
        output_dtype=joined.dtype,
        inputs=tensors,
        result_tensors=[joined],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return joined
Eric Kunzee5e26762020-10-13 16:11:07 -07001356
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a pad operator and return its output tensor.

    ``padding`` is flattened before being stored in the PadAttribute along
    with both pad constants. The attribute is built before validation
    runs, so it is created even when the test is subsequently rejected.

    Returns:
        The output tensor from OutputShaper.padOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    padded = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    flat_padding = padding.flatten()
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [padded.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=padded.shape,
        input_dtype=a.dtype,
        output_dtype=padded.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[padded],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
        input1=a,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], names_in, names_out, pad_attr)
    return padded
Eric Kunzee5e26762020-10-13 16:11:07 -07001405
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a dim operator for ``axis`` and return its output tensor.

    The axis is stored in an AxisAttribute. ``qinfo`` is accepted but not
    used by this builder.

    Returns:
        The output tensor from OutputShaper.dimOp, or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    dim_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [dim_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=dim_tens.shape,
        output_dtype=dim_tens.dtype,
        result_tensors=[dim_tens],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], names_in, names_out, axis_attr)
    return dim_tens
1448
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` reshaping ``a`` to ``newShape``.

    The output tensor is produced by ``OutputShaper.reshapeOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], input_list, output_list, reshape_attr)
    return result_tens
1484
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``a`` with an axis attribute.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "axis": axis,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result_tens
1519
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``a`` with ``perms``.

    The output tensor is produced by ``OutputShaper.transposeOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation in this builder.
    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "perms": perms,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
        "input1": a,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)
    return result_tens
1555
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``a`` with ``start``/``size``.

    The output tensor is produced by ``OutputShaper.sliceOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "start": start,
        "size": size,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
        "input1": a,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], input_list, output_list, slice_attr)
    return result_tens
1594
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``a`` with ``multiples``.

    The output tensor is produced by ``OutputShaper.tileOp``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": result_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
        "input1": a,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], input_list, output_list, tile_attr)
    return result_tens
1629
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``values`` and random indices.

    A constant INT32 indices tensor of shape ``(values.shape[0], W)`` is
    generated here, with ``W`` drawn from ``self.args.tensor_shape_range`` and
    every index in ``[0, values.shape[1])``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Keep the generated indices within dimension 1 of the values tensor.
    index_limit = values.shape[1]  # K
    index_count = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    indices_arr = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values.shape[0], index_count]
        )
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": values.shape,
        "output_shape": result_tens.shape,
        "input_dtype": values.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1676
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` on ``values_in``, indices, ``input``.

    A constant INT32 indices tensor of shape
    ``(values_in.shape[0], input.shape[1])`` is generated here with every
    index in ``[0, values_in.shape[1])``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Keep the generated indices within dimension 1 of the values_in tensor.
    # NOTE(review): indices are drawn independently, so a row may contain
    # duplicates - confirm whether the operator permits repeated indices.
    index_limit = values_in.shape[1]  # K
    index_count = input.shape[1]  # W
    indices_arr = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values_in.shape[0], index_count]
        )
    )  # (N, W)
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [result_tens.name],
    )

    check_args = {
        "op": op,
        "input_shape": values_in.shape,
        "output_shape": result_tens.shape,
        "input_dtype": values_in.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1721
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize the operator described by ``op`` resizing ``input``.

    The output tensor is produced by ``OutputShaper.resizeOp`` and the
    attribute carries ``scale``, ``offset``, ``border`` and ``mode``.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": result_tens.shape,
        "offset": offset,
        "border": border,
        "input_list": input_list,
        "output_list": output_list,
        "result_tensors": [result_tens],
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1783
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output operator on ``val`` and ``val2``.

    Each output tensor is produced by ``OutputShaper.unaryOp`` from the
    matching input.  ``validator_fcns`` is accepted but no ERROR_IF
    validation is performed here.

    Returns the first result tensor only.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Every sibling builder is handed the operator descriptor and serializes
    # op["op"]; do the same here, while still accepting a bare opcode so any
    # caller that already passes one keeps working.
    opcode = op["op"] if isinstance(op, dict) else op

    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1791
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted but unused.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001795
1796 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize the operator described by ``op`` converting ``val`` to ``out_dtype``.

    The output tensor is produced by ``OutputShaper.typeConversionOp``; no
    attribute is attached to the operator.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # ERROR_IF tests may require the input/output name lists to be altered.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": val.shape,
        "output_shape": result_tens.shape,
        "input_dtype": val.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensors": [result_tens],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": placeholder_count + const_count,
    }
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1829
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize the operator described by ``op`` rescaling ``val`` to ``out_dtype``.

    Random zero points and per-channel (or single) scale factors are
    generated, converted to multiplier/shift pairs with
    ``TosaQuantGen.computeMultiplierAndShift``, and stored in a rescale
    attribute.  For ``scale32`` tests without an ``error_name`` the
    placeholder data file of ``val`` is rewritten if any value falls outside
    the range implied by the computed shifts.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds int32 values; widen explicitly so the bounds are
        # computed in 64 bits whatever scalar promotion rules NumPy applies
        # (an int32 result would overflow once shift - 1 reaches 31).
        half_range = np.int64(1) << (np.int64(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        # Clamp in int64 relative to the zero point, then restore the
        # original dtype when adding the zero point back.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1984
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return a constant tensor holding the condition value ``cond``.

    The tensor is BOOL with an empty shape unless ``error_name`` asks for a
    wrong dtype (``CondIfCondNotMatchingBool``) or a wrong shape
    (``CondIfCondShapeNotSizeOne``, randomly ``[2]`` or ``[1, 2]``).
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2001
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a conditional operator whose then/else blocks hold constants.

    Only the shape of ``then_tens`` is used from the supplied tensors; each
    block is filled with a freshly generated random INT32 constant.  For the
    output-list mismatch errors the affected block gets a constant with a
    deliberately different shape.

    Returns the result tensor, or ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` reports failure.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    # Random contents for the two blocks
    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Build the actual then/else tensors inside their blocks
    block_plan = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, mismatch_error, block_arr in block_plan:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_tens)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_ok:
        return None

    return result_tens
2070
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary operator to a and b.

    For FP32/BF16/FP16/INT32 inputs the blocks perform ADD / SUB; for
    INT8/INT16 inputs they perform LOGICAL_RIGHT_SHIFT / LOGICAL_LEFT_SHIFT.
    For the CondIf{Input,Output}List{Then,Else}GraphMismatch error cases the
    matching block is given a deliberately perturbed input or output shape.

    Args:
        op: operator definition dictionary; op["op"] is the serialized operator.
        a: first input tensor (its shape/dtype define the result tensor).
        b: second input tensor.
        cond: condition value handed to self._get_condition_tensor().
        validator_fcns: ERROR_IF validators handed to
            TosaErrorValidator.evValidateErrorIfs().
        error_name: name of the ERROR_IF case being generated, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.

    Raises:
        AssertionError: if a.dtype has no then/else operator pairing.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension so the block tensor cannot match a's shape
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise (rather than `assert False`) so the check is not
        # stripped under `python -O`, which would leave then_op unbound.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # NOTE: the loop variable must not be called `op`; doing so would rebind
    # the `op` parameter and pass an Op enum value to the validators below
    # instead of the operator dictionary.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # Wrong shape on the block's first input
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # Wrong shape on the block's output
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2153
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that adds `a` into an accumulator `iter_val` times.

    The COND block tests (iter > 0); the BODY block computes acc + a and
    iter - 1.  For the supported ERROR_IF cases, shapes or the condition
    tensor's type/shape are deliberately made inconsistent.

    Returns:
        The accumulator output tensor, or None when ERROR_IF validation fails.
    """
    shape_deltas = [-3, -2, 2, 3]

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape = acc.shape
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for dim in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[dim] += self.rng.choice(shape_deltas)
        acc_out_shape = incorrect_acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for dim in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[dim] += self.rng.choice(shape_deltas)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 shape has nothing to perturb, so add a dimension
            incorrect_iter.shape.append(self.rng.choice(shape_deltas))

        incorrect_acc = deepcopy(acc)
        for dim in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[dim] += self.rng.choice(shape_deltas)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    cond_inputs = (
        (incorrect_iter, a, incorrect_acc)
        if error_name == ErrorIf.InputListCondGraphMismatch
        else (iter_tens, a, acc)
    )
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])

    cond_shape = []
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    body_inputs = (
        (incorrect_iter, a, incorrect_acc)
        if error_name == ErrorIf.InputListBodyGraphInputMismatch
        else (iter_tens, a, acc)
    )
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = incorrect_iter, incorrect_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tensor in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not validated:
        return None

    return acc_out
2274
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Serialize an FFT2D operator taking val1 and val2 as inputs.

    Output tensors come from OutputShaper.fft2dOp(); the operator is only
    added to the serializer once the ERROR_IF validators have passed.

    Returns:
        The list of result tensors, or None when ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    in_names = [val1.name, val2.name]
    out_names = []
    out_shapes = []
    out_dtypes = []
    for res in results:
        out_names.append(res.name)
        out_shapes.append(res.shape)
        out_dtypes.append(res.dtype)

    # May deliberately corrupt the input/output name lists for ERROR_IF tests
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    validator_kwargs = {
        "op": op,
        "inverse": inverse,
        "input1": val1,
        "input2": val2,
        "input_shape": val1.shape,
        "input_dtype": val1.dtype,
        "output_shape": out_shapes,
        "output_dtype": out_dtypes,
        "result_tensors": results,
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    ):
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2316
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Serialize an RFFT2D operator taking val as its single input.

    Output tensors come from OutputShaper.rfft2dOp(); the operator is only
    added to the serializer once the ERROR_IF validators have passed.

    Returns:
        The list of result tensors, or None when ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    in_names = [val.name]
    out_names = []
    out_shapes = []
    out_dtypes = []
    for res in results:
        out_names.append(res.name)
        out_shapes.append(res.shape)
        out_dtypes.append(res.dtype)

    # May deliberately corrupt the input/output name lists for ERROR_IF tests
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    validator_kwargs = {
        "op": op,
        "input_shape": val.shape,
        "input_dtype": val.dtype,
        "output_shape": out_shapes,
        "output_dtype": out_dtypes,
        "result_tensors": results,
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_kwargs
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2350
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Args:
        op: operator dictionary providing "rank" (min, max) and "types".
        shapeFilter: list of requested shapes; empty/None means [None].
        rankFilter: requested ranks, or None for the default behaviour.
        dtypeFilter: requested dtypes, or None for all operator dtypes.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator (or an unrecognised
        testType).
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Work out the rank filter from the request and the operator's limits
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Default behaviour is bounded by the default range or by the
        # operator, whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) > len(default_test_rank_range):
            cleanRankFilter = default_test_rank_range
        else:
            cleanRankFilter = opRankRange
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    # Operator dtypes, optionally narrowed to the requested dtypes
    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        # Parameters the validator requires override the cleaned filters
        error_arguments = validator(check=False, op=op)["param_reqs"]

        required_rank = error_arguments["rank"]
        rankFilter = cleanRankFilter if required_rank is None else required_rank

        required_dtype = error_arguments["dtype"]
        dtypeFilter = cleanDtypeFilter if required_dtype is None else required_dtype

        required_shape = error_arguments["shape"]
        if required_shape is None:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]
        else:
            shapeFilter = required_shape

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2431
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for one operator.

    Each entry of the returned list is a tuple of:
        (opName, testNameStr, dtype, error_name, shapeList, argumentsList)

    Re-seeds self.rng from self.random_seed on every call.  The shapeFilter
    default list is only read here, never mutated.

    Raises:
        Exception: if opName is not a key of self.TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    negative_with_validators = testType == "negative" and "error_if_validators" in op
    error_if_validators = (
        op["error_if_validators"] if negative_with_validators else [None]
    )

    testList = []
    for validator in error_if_validators:
        error_name = None
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        rank_choices = filterDict["rankFilter"]
        dtype_choices = filterDict["dtypeFilter"]
        shape_choices = filterDict["shapeFilter"]

        for r in rank_choices:
            for t in dtype_choices:
                for shape in shape_choices:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list: tuples of (string representation,
                    # build function argument list)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
    if testType == "positive" and "invalid_test_validators" in op:
        invalid_test_validators = op["invalid_test_validators"]
        kept = []
        for test in testList:
            # Every validator is called for every test (no short-circuit)
            verdicts = [
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            ]
            if not any(verdicts):
                kept.append(test)
        testList = kept

    return testList
2538
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test for opName and serialize it when the build succeeds.

    Creates a serializer, generates the tensor operands through the op's
    tensor-values generator, calls the op's build function and, if that
    returns a truthy result, serializes the test with any collected meta
    data.  Otherwise a message is printed and nothing is serialized.

    Raises:
        Exception: if opName is not a key of self.TOSA_OP_LIST.
        TypeError: re-raised from the build function after printing context.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    is_concat = op["op"] == Op.CONCAT
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        # CONCAT has one dtype entry per input shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Optional quantization info generator
    qgen = op.get("qgen")
    qinfo = None
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Build the random tensor operands.
    # Check we are using the new testArgs interface with an argsDict dictionary
    if len(testArgs) == 1 and isinstance(testArgs[0], dict):
        argsDict = testArgs[0]
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(
            self, opName, dtypeList, shapeList, argsDict, error_name
        )
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
    else:
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # Without validators qinfo (when present) is passed positionally
            trailing = () if qinfo is None else (qinfo,)
            result = build_fcn(self, op, *tens, *testArgs, *trailing)
        else:
            build_kwargs = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
            result = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
        # Add the compliance meta data
        # NOTE: This currently expects only one result output
        tensMeta["compliance"] = {
            "version": "0.1",
            "tensors": {result.resultTensor.name: result.complianceDict},
        }
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002652
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE ops into one op per kernel size.

    Each generated entry is a shallow copy of its template with "filter"
    set to the kernel and "template" set to False.  All entries still
    flagged as templates are then removed from self.TOSA_OP_LIST.
    """
    op_list = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if self.args.level8k:
        bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
        KERNELS_2D = [[1, bigK], [bigK, 2]]
        KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
    else:
        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # (test name format, template key) pairs for the 2D kernel ops
    templates_2d = (
        ("conv2d_{}x{}", "conv2d_TEMPLATE"),
        ("depthwise_conv2d_{}x{}", "depthwise_conv2d_TEMPLATE"),
        ("transpose_conv2d_{}x{}", "transpose_conv2d_TEMPLATE"),
    )

    for kernel in KERNELS_2D:
        for name_format, template_key in templates_2d:
            entry = op_list[template_key].copy()
            entry["filter"] = kernel
            entry["template"] = False
            op_list[name_format.format(kernel[0], kernel[1])] = entry

    for kernel in KERNELS_3D:
        entry = op_list["conv3d_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list["conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2])] = entry

    # Delete any templates after having created any dynamic ops.
    # Collect the keys first because it's bad practice to delete keys from
    # a dictionary while iterating over it.
    template_keys = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_keys:
        del op_list[name]
2708
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting)."""
    # Fields that only need to be present, with the error raised when missing
    required_fields = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for op in self.TOSA_OP_LIST:
        entry = self.TOSA_OP_LIST[op]

        # "operands" must unpack into exactly two values
        try:
            _, _ = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            )

        # "build_fcn" must unpack into exactly four values
        try:
            _, _, _, _ = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            )

        for field, message in required_fields:
            try:
                entry[field]
            except KeyError:
                raise Exception(message.format(op))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002750
2751 # Tensor operator list
2752 # 'op': op name
2753 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002754 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2755 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002756 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2757 # 'types': array of datatypes to be tested
# Data type groups referenced by the "types" field of the operator entries.
# Element order is kept exactly as declared, in case consumers of these
# lists depend on it.

# Floating-point types.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (INT4 is deliberately left out).
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]

# Integer types followed by the floating-point types (INT4 is left out).
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

# Boolean only.
TYPE_BOOL = [DType.BOOL]

# Floating-point types and INT32.
TYPE_FI32 = TYPE_FP + [DType.INT32]

# Floating-point, integer and boolean types.
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL

# FP32 and INT16.
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integer types plus the floating-point types.
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# Each row is [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    # Integer combinations
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    # Floating-point combinations
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range given to operator entries that do not declare their own "rank".
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002802
2803 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002804 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002805 "argmax": {
2806 "op": Op.ARGMAX,
2807 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002808 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002809 "build_fcn": (
2810 build_argmax,
2811 TosaTensorGen.tgBasic,
2812 TosaTensorValuesGen.tvgDefault,
2813 TosaArgGen.agAxis,
2814 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002815 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002816 "error_if_validators": (
2817 TosaErrorValidator.evAxisSmallerZero,
2818 TosaErrorValidator.evAxisLargerRank,
2819 TosaErrorValidator.evArgmaxOutputRankMismatch,
2820 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2821 TosaErrorValidator.evWrongRank,
2822 TosaErrorValidator.evWrongInputType,
2823 TosaErrorValidator.evWrongOutputType,
2824 TosaErrorValidator.evWrongInputList,
2825 TosaErrorValidator.evWrongOutputList,
2826 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002828 "avg_pool2d": {
2829 "op": Op.AVG_POOL2D,
2830 "operands": (1, 0),
2831 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002832 "build_fcn": (
2833 build_pool2d,
2834 TosaTensorGen.tgNHWC,
2835 TosaTensorValuesGen.tvgDefault,
2836 TosaArgGen.agPooling,
2837 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002838 "qgen": TosaQuantGen.qgUnary,
2839 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002840 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002841 "error_if_validators": (
2842 TosaErrorValidator.evKernelSmallerOne,
2843 TosaErrorValidator.evStrideSmallerOne,
2844 TosaErrorValidator.evPadSmallerZero,
2845 TosaErrorValidator.evWrongRank,
2846 TosaErrorValidator.evWrongInputType,
2847 TosaErrorValidator.evWrongOutputType,
2848 TosaErrorValidator.evWrongInputList,
2849 TosaErrorValidator.evWrongOutputList,
2850 TosaErrorValidator.evInputZeroPointNotZero,
2851 TosaErrorValidator.evOutputZeroPointNotZero,
2852 TosaErrorValidator.evPadLargerEqualKernel,
2853 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002854 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002857 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002858 "conv2d_TEMPLATE": {
2859 "op": Op.CONV2D,
2860 "operands": (1, 2),
2861 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002862 "build_fcn": (
2863 build_conv2d,
2864 TosaTensorGen.tgConv2D,
2865 TosaTensorValuesGen.tvgDefault,
2866 TosaArgGen.agConv,
2867 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002868 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002869 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002870 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2871 "error_if_validators": (
2872 TosaErrorValidator.evWrongInputType,
2873 TosaErrorValidator.evWrongOutputType,
2874 TosaErrorValidator.evWrongInputList,
2875 TosaErrorValidator.evWrongOutputList,
2876 TosaErrorValidator.evInputZeroPointNotZero,
2877 TosaErrorValidator.evWeightZeroPointNotZero,
2878 TosaErrorValidator.evPadSmallerZero,
2879 TosaErrorValidator.evStrideSmallerOne,
2880 TosaErrorValidator.evDilationSmallerOne,
2881 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002882 TosaErrorValidator.evConvOutputShapeMismatch,
2883 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002884 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002885 "template": True,
2886 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002887 # Templated operator. Filled in by createDynamicOpLists
2888 "conv3d_TEMPLATE": {
2889 "op": Op.CONV3D,
2890 "operands": (1, 2),
2891 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 "build_fcn": (
2893 build_conv3d,
2894 TosaTensorGen.tgConv3D,
2895 TosaTensorValuesGen.tvgDefault,
2896 TosaArgGen.agConv,
2897 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002898 "qgen": TosaQuantGen.qgConv,
2899 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002900 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2901 "error_if_validators": (
2902 TosaErrorValidator.evWrongInputType,
2903 TosaErrorValidator.evWrongOutputType,
2904 TosaErrorValidator.evWrongInputList,
2905 TosaErrorValidator.evWrongOutputList,
2906 TosaErrorValidator.evInputZeroPointNotZero,
2907 TosaErrorValidator.evWeightZeroPointNotZero,
2908 TosaErrorValidator.evPadSmallerZero,
2909 TosaErrorValidator.evStrideSmallerOne,
2910 TosaErrorValidator.evDilationSmallerOne,
2911 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002912 TosaErrorValidator.evConvOutputShapeMismatch,
2913 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002914 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002915 "template": True,
2916 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002917 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002918 "depthwise_conv2d_TEMPLATE": {
2919 "op": Op.DEPTHWISE_CONV2D,
2920 "operands": (1, 2),
2921 "filter": [1, 1],
2922 "rank": (4, 4),
2923 "build_fcn": (
2924 build_depthwise_conv2d,
2925 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002927 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002928 ),
2929 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002930 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002931 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2932 "error_if_validators": (
2933 TosaErrorValidator.evWrongInputType,
2934 TosaErrorValidator.evWrongOutputType,
2935 TosaErrorValidator.evWrongInputList,
2936 TosaErrorValidator.evWrongOutputList,
2937 TosaErrorValidator.evInputZeroPointNotZero,
2938 TosaErrorValidator.evWeightZeroPointNotZero,
2939 TosaErrorValidator.evPadSmallerZero,
2940 TosaErrorValidator.evStrideSmallerOne,
2941 TosaErrorValidator.evDilationSmallerOne,
2942 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002943 TosaErrorValidator.evConvOutputShapeMismatch,
2944 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002945 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002946 "template": True,
2947 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002948 "fully_connected": {
2949 "op": Op.FULLY_CONNECTED,
2950 "operands": (1, 2),
2951 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002952 "build_fcn": (
2953 build_fully_connected,
2954 TosaTensorGen.tgFullyConnected,
2955 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002956 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002957 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002958 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002959 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002960 "error_if_validators": (
2961 TosaErrorValidator.evInputZeroPointNotZero,
2962 TosaErrorValidator.evWeightZeroPointNotZero,
2963 TosaErrorValidator.evWrongRank,
2964 TosaErrorValidator.evWrongInputType,
2965 TosaErrorValidator.evWrongOutputType,
2966 TosaErrorValidator.evWrongInputList,
2967 TosaErrorValidator.evWrongOutputList,
2968 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 "matmul": {
2971 "op": Op.MATMUL,
2972 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002973 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002974 "build_fcn": (
2975 build_matmul,
2976 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002977 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01002978 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 "qgen": TosaQuantGen.qgMatmul,
2981 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002982 "error_if_validators": (
2983 TosaErrorValidator.evInputZeroPointNotZero,
2984 TosaErrorValidator.evWrongRank,
2985 TosaErrorValidator.evWrongInputType,
2986 TosaErrorValidator.evWrongOutputType,
2987 TosaErrorValidator.evWrongInputList,
2988 TosaErrorValidator.evWrongOutputList,
2989 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01002990 "data_gen": {
2991 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2992 "int": (gtu.DataGenType.PSEUDO_RANDOM,),
2993 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002995 "max_pool2d": {
2996 "op": Op.MAX_POOL2D,
2997 "operands": (1, 0),
2998 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002999 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01003000 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003001 TosaTensorGen.tgNHWC,
3002 TosaTensorValuesGen.tvgDefault,
3003 TosaArgGen.agPooling,
3004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003005 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003006 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003007 "error_if_validators": (
3008 TosaErrorValidator.evKernelSmallerOne,
3009 TosaErrorValidator.evStrideSmallerOne,
3010 TosaErrorValidator.evPadSmallerZero,
3011 TosaErrorValidator.evWrongRank,
3012 TosaErrorValidator.evWrongInputType,
3013 TosaErrorValidator.evWrongOutputType,
3014 TosaErrorValidator.evWrongInputList,
3015 TosaErrorValidator.evWrongOutputList,
3016 TosaErrorValidator.evPadLargerEqualKernel,
3017 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003018 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003021 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003022 "transpose_conv2d_TEMPLATE": {
3023 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003024 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003025 "rank": (4, 4),
3026 "build_fcn": (
3027 build_transpose_conv2d,
3028 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003029 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003030 TosaArgGen.agTransposeConv2D,
3031 ),
3032 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003033 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003034 "invalid_test_validators": (
3035 TosaInvalidValidator.ivHeightWidthInvalid,
3036 TosaInvalidValidator.ivNonPositiveOutputShape,
3037 ),
3038 "error_if_validators": (
3039 TosaErrorValidator.evWrongInputType,
3040 TosaErrorValidator.evWrongOutputType,
3041 TosaErrorValidator.evWrongInputList,
3042 TosaErrorValidator.evWrongOutputList,
3043 TosaErrorValidator.evInputZeroPointNotZero,
3044 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003045 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003046 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003047 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003048 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003049 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003050 "template": True,
3051 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003052 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003053 "clamp": {
3054 "op": Op.CLAMP,
3055 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003056 "build_fcn": (
3057 build_clamp,
3058 TosaTensorGen.tgBasic,
3059 TosaTensorValuesGen.tvgDefault,
3060 None,
3061 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003062 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003063 "error_if_validators": (
3064 TosaErrorValidator.evMaxSmallerMin,
3065 TosaErrorValidator.evWrongInputType,
3066 TosaErrorValidator.evWrongOutputType,
3067 TosaErrorValidator.evWrongInputList,
3068 TosaErrorValidator.evWrongOutputList,
3069 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003070 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003071 "sigmoid": {
3072 "op": Op.SIGMOID,
3073 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 "build_fcn": (
3075 build_sigmoid,
3076 TosaTensorGen.tgBasic,
3077 TosaTensorValuesGen.tvgDefault,
3078 None,
3079 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003080 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 "error_if_validators": (
3082 TosaErrorValidator.evWrongInputType,
3083 TosaErrorValidator.evWrongOutputType,
3084 TosaErrorValidator.evWrongInputList,
3085 TosaErrorValidator.evWrongOutputList,
3086 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003087 },
3088 "tanh": {
3089 "op": Op.TANH,
3090 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003091 "build_fcn": (
3092 build_tanh,
3093 TosaTensorGen.tgBasic,
3094 TosaTensorValuesGen.tvgDefault,
3095 None,
3096 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003097 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evWrongInputType,
3100 TosaErrorValidator.evWrongOutputType,
3101 TosaErrorValidator.evWrongInputList,
3102 TosaErrorValidator.evWrongOutputList,
3103 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003104 },
Won Jeon78155c62023-06-10 00:20:04 +00003105 "erf": {
3106 "op": Op.ERF,
3107 "operands": (1, 0),
3108 "build_fcn": (
3109 build_erf,
3110 TosaTensorGen.tgBasic,
3111 TosaTensorValuesGen.tvgDefault,
3112 None,
3113 ),
3114 "types": TYPE_FP,
3115 "error_if_validators": (
3116 TosaErrorValidator.evWrongInputType,
3117 TosaErrorValidator.evWrongOutputType,
3118 TosaErrorValidator.evWrongInputList,
3119 TosaErrorValidator.evWrongOutputList,
3120 ),
3121 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003122 # Elementwise Binary Operators
3123 "add": {
3124 "op": Op.ADD,
3125 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003126 "build_fcn": (
3127 build_binary_broadcast,
3128 TosaTensorGen.tgBroadcastFuzz,
3129 TosaTensorValuesGen.tvgAddSub,
3130 None,
3131 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003133 "error_if_validators": (
3134 TosaErrorValidator.evRankMismatch,
3135 TosaErrorValidator.evWrongInputType,
3136 TosaErrorValidator.evWrongOutputType,
3137 TosaErrorValidator.evWrongInputList,
3138 TosaErrorValidator.evWrongOutputList,
3139 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003140 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "arithmetic_right_shift": {
3144 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3145 "operands": (2, 0),
3146 "build_fcn": (
3147 build_arithmetic_right_shift,
3148 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 TosaArgGen.agArithmeticRightShift,
3151 ),
3152 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 "error_if_validators": (
3154 TosaErrorValidator.evRankMismatch,
3155 TosaErrorValidator.evWrongInputType,
3156 TosaErrorValidator.evWrongOutputType,
3157 TosaErrorValidator.evWrongInputList,
3158 TosaErrorValidator.evWrongOutputList,
3159 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003160 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003161 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003162 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003163 "bitwise_and": {
3164 "op": Op.BITWISE_AND,
3165 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 "build_fcn": (
3167 build_binary_broadcast,
3168 TosaTensorGen.tgBroadcastFuzz,
3169 TosaTensorValuesGen.tvgDefault,
3170 None,
3171 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003172 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003173 "error_if_validators": (
3174 TosaErrorValidator.evRankMismatch,
3175 TosaErrorValidator.evWrongInputType,
3176 TosaErrorValidator.evWrongOutputType,
3177 TosaErrorValidator.evWrongInputList,
3178 TosaErrorValidator.evWrongOutputList,
3179 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003180 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003181 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "bitwise_or": {
3184 "op": Op.BITWISE_OR,
3185 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186 "build_fcn": (
3187 build_binary_broadcast,
3188 TosaTensorGen.tgBroadcastFuzz,
3189 TosaTensorValuesGen.tvgDefault,
3190 None,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003193 "error_if_validators": (
3194 TosaErrorValidator.evRankMismatch,
3195 TosaErrorValidator.evWrongInputType,
3196 TosaErrorValidator.evWrongOutputType,
3197 TosaErrorValidator.evWrongInputList,
3198 TosaErrorValidator.evWrongOutputList,
3199 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003200 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003201 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003202 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 "bitwise_xor": {
3204 "op": Op.BITWISE_XOR,
3205 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003206 "build_fcn": (
3207 build_binary_broadcast,
3208 TosaTensorGen.tgBroadcastFuzz,
3209 TosaTensorValuesGen.tvgDefault,
3210 None,
3211 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003213 "error_if_validators": (
3214 TosaErrorValidator.evRankMismatch,
3215 TosaErrorValidator.evWrongInputType,
3216 TosaErrorValidator.evWrongOutputType,
3217 TosaErrorValidator.evWrongInputList,
3218 TosaErrorValidator.evWrongOutputList,
3219 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003220 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003221 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003222 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003223 "intdiv": {
3224 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003225 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003226 "build_fcn": (
3227 build_binary_broadcast,
3228 TosaTensorGen.tgBroadcastFuzz,
3229 TosaTensorValuesGen.tvgIntDiv,
3230 None,
3231 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003232 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003233 "error_if_validators": (
3234 TosaErrorValidator.evRankMismatch,
3235 TosaErrorValidator.evWrongInputType,
3236 TosaErrorValidator.evWrongOutputType,
3237 TosaErrorValidator.evWrongInputList,
3238 TosaErrorValidator.evWrongOutputList,
3239 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003240 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003241 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003242 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 "logical_and": {
3244 "op": Op.LOGICAL_AND,
3245 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 "build_fcn": (
3247 build_binary_broadcast,
3248 TosaTensorGen.tgBroadcastFuzz,
3249 TosaTensorValuesGen.tvgDefault,
3250 None,
3251 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003253 "error_if_validators": (
3254 TosaErrorValidator.evRankMismatch,
3255 TosaErrorValidator.evWrongInputType,
3256 TosaErrorValidator.evWrongOutputType,
3257 TosaErrorValidator.evWrongInputList,
3258 TosaErrorValidator.evWrongOutputList,
3259 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003260 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "logical_left_shift": {
3264 "op": Op.LOGICAL_LEFT_SHIFT,
3265 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 "build_fcn": (
3267 build_binary_broadcast,
3268 TosaTensorGen.tgBroadcastFuzz,
3269 TosaTensorValuesGen.tvgLogicalShift,
3270 None,
3271 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003272 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003273 "error_if_validators": (
3274 TosaErrorValidator.evRankMismatch,
3275 TosaErrorValidator.evWrongInputType,
3276 TosaErrorValidator.evWrongOutputType,
3277 TosaErrorValidator.evWrongInputList,
3278 TosaErrorValidator.evWrongOutputList,
3279 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003280 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003281 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "logical_right_shift": {
3284 "op": Op.LOGICAL_RIGHT_SHIFT,
3285 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286 "build_fcn": (
3287 build_binary_broadcast,
3288 TosaTensorGen.tgBroadcastFuzz,
3289 TosaTensorValuesGen.tvgLogicalShift,
3290 None,
3291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 "error_if_validators": (
3294 TosaErrorValidator.evRankMismatch,
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003300 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003301 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 "logical_or": {
3304 "op": Op.LOGICAL_OR,
3305 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003306 "build_fcn": (
3307 build_binary_broadcast,
3308 TosaTensorGen.tgBroadcastFuzz,
3309 TosaTensorValuesGen.tvgDefault,
3310 None,
3311 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 "error_if_validators": (
3314 TosaErrorValidator.evRankMismatch,
3315 TosaErrorValidator.evWrongInputType,
3316 TosaErrorValidator.evWrongOutputType,
3317 TosaErrorValidator.evWrongInputList,
3318 TosaErrorValidator.evWrongOutputList,
3319 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003320 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003321 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003323 "logical_xor": {
3324 "op": Op.LOGICAL_XOR,
3325 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003326 "build_fcn": (
3327 build_binary_broadcast,
3328 TosaTensorGen.tgBroadcastFuzz,
3329 TosaTensorValuesGen.tvgDefault,
3330 None,
3331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003333 "error_if_validators": (
3334 TosaErrorValidator.evRankMismatch,
3335 TosaErrorValidator.evWrongInputType,
3336 TosaErrorValidator.evWrongOutputType,
3337 TosaErrorValidator.evWrongInputList,
3338 TosaErrorValidator.evWrongOutputList,
3339 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003340 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003341 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "maximum": {
3344 "op": Op.MAXIMUM,
3345 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 "build_fcn": (
3347 build_binary_broadcast,
3348 TosaTensorGen.tgBroadcastFuzz,
3349 TosaTensorValuesGen.tvgDefault,
3350 None,
3351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003353 "error_if_validators": (
3354 TosaErrorValidator.evRankMismatch,
3355 TosaErrorValidator.evWrongInputType,
3356 TosaErrorValidator.evWrongOutputType,
3357 TosaErrorValidator.evWrongInputList,
3358 TosaErrorValidator.evWrongOutputList,
3359 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003360 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003361 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003362 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 "minimum": {
3364 "op": Op.MINIMUM,
3365 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003366 "build_fcn": (
3367 build_binary_broadcast,
3368 TosaTensorGen.tgBroadcastFuzz,
3369 TosaTensorValuesGen.tvgDefault,
3370 None,
3371 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003373 "error_if_validators": (
3374 TosaErrorValidator.evRankMismatch,
3375 TosaErrorValidator.evWrongInputType,
3376 TosaErrorValidator.evWrongOutputType,
3377 TosaErrorValidator.evWrongInputList,
3378 TosaErrorValidator.evWrongOutputList,
3379 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003380 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "mul": {
3384 "op": Op.MUL,
3385 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003386 "build_fcn": (
3387 build_mul,
3388 TosaTensorGen.tgBroadcastFuzz,
3389 TosaTensorValuesGen.tvgMul,
3390 TosaArgGen.agMul,
3391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003393 "error_if_validators": (
3394 TosaErrorValidator.evWrongInputType,
3395 TosaErrorValidator.evWrongOutputType,
3396 TosaErrorValidator.evWrongInputList,
3397 TosaErrorValidator.evWrongOutputList,
3398 TosaErrorValidator.evRankMismatch,
3399 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003400 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003401 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 "pow": {
3404 "op": Op.POW,
3405 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003406 "build_fcn": (
3407 build_binary_broadcast,
3408 TosaTensorGen.tgBroadcastFuzz,
3409 TosaTensorValuesGen.tvgDefault,
3410 None,
3411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003412 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003413 "error_if_validators": (
3414 TosaErrorValidator.evRankMismatch,
3415 TosaErrorValidator.evWrongInputType,
3416 TosaErrorValidator.evWrongOutputType,
3417 TosaErrorValidator.evWrongInputList,
3418 TosaErrorValidator.evWrongOutputList,
3419 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003420 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003422 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 "sub": {
3424 "op": Op.SUB,
3425 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003426 "build_fcn": (
3427 build_binary_broadcast,
3428 TosaTensorGen.tgBroadcastFuzz,
3429 TosaTensorValuesGen.tvgAddSub,
3430 None,
3431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003433 "error_if_validators": (
3434 TosaErrorValidator.evRankMismatch,
3435 TosaErrorValidator.evWrongInputType,
3436 TosaErrorValidator.evWrongOutputType,
3437 TosaErrorValidator.evWrongInputList,
3438 TosaErrorValidator.evWrongOutputList,
3439 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003440 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003441 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003442 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 "table": {
3444 "op": Op.TABLE,
3445 # Use the automatic generation functions to create the input array
3446 # but create the table tensor in the build function, as it may be
3447 # a different type from the input
3448 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 "build_fcn": (
3450 build_table,
3451 TosaTensorGen.tgBasic,
3452 TosaTensorValuesGen.tvgDefault,
3453 TosaArgGen.agTable,
3454 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003455 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003456 "error_if_validators": (
3457 TosaErrorValidator.evWrongInputType,
3458 TosaErrorValidator.evWrongOutputType,
3459 TosaErrorValidator.evWrongInputList,
3460 TosaErrorValidator.evWrongOutputList,
3461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 # Elementwise Unary operators
3464 "abs": {
3465 "op": Op.ABS,
3466 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_unary,
3469 TosaTensorGen.tgBasic,
3470 TosaTensorValuesGen.tvgDefault,
3471 None,
3472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList,
3478 TosaErrorValidator.evWrongOutputList,
3479 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003481 "bitwise_not": {
3482 "op": Op.BITWISE_NOT,
3483 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003484 "build_fcn": (
3485 build_unary,
3486 TosaTensorGen.tgBasic,
3487 TosaTensorValuesGen.tvgDefault,
3488 None,
3489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003490 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003491 "error_if_validators": (
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongInputList,
3495 TosaErrorValidator.evWrongOutputList,
3496 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003497 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 "ceil": {
3499 "op": Op.CEIL,
3500 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003501 "build_fcn": (
3502 build_unary,
3503 TosaTensorGen.tgBasic,
3504 TosaTensorValuesGen.tvgDefault,
3505 None,
3506 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003507 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003508 "error_if_validators": (
3509 TosaErrorValidator.evWrongInputType,
3510 TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongInputList,
3512 TosaErrorValidator.evWrongOutputList,
3513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "clz": {
3516 "op": Op.CLZ,
3517 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003518 "build_fcn": (
3519 build_unary,
3520 TosaTensorGen.tgBasic,
3521 TosaTensorValuesGen.tvgDefault,
3522 None,
3523 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003525 "error_if_validators": (
3526 TosaErrorValidator.evWrongInputType,
3527 TosaErrorValidator.evWrongOutputType,
3528 TosaErrorValidator.evWrongInputList,
3529 TosaErrorValidator.evWrongOutputList,
3530 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003532 "exp": {
3533 "op": Op.EXP,
3534 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 "build_fcn": (
3536 build_unary,
3537 TosaTensorGen.tgBasic,
3538 TosaTensorValuesGen.tvgDefault,
3539 None,
3540 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003541 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003542 "error_if_validators": (
3543 TosaErrorValidator.evWrongInputType,
3544 TosaErrorValidator.evWrongOutputType,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "floor": {
3550 "op": Op.FLOOR,
3551 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_unary,
3554 TosaTensorGen.tgBasic,
3555 TosaTensorValuesGen.tvgDefault,
3556 None,
3557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003559 "error_if_validators": (
3560 TosaErrorValidator.evWrongInputType,
3561 TosaErrorValidator.evWrongOutputType,
3562 TosaErrorValidator.evWrongInputList,
3563 TosaErrorValidator.evWrongOutputList,
3564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003565 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003566 "log": {
3567 "op": Op.LOG,
3568 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 "build_fcn": (
3570 build_unary,
3571 TosaTensorGen.tgBasic,
3572 TosaTensorValuesGen.tvgDefault,
3573 None,
3574 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003575 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 "error_if_validators": (
3577 TosaErrorValidator.evWrongInputType,
3578 TosaErrorValidator.evWrongOutputType,
3579 TosaErrorValidator.evWrongInputList,
3580 TosaErrorValidator.evWrongOutputList,
3581 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "logical_not": {
3584 "op": Op.LOGICAL_NOT,
3585 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_unary,
3588 TosaTensorGen.tgBasic,
3589 TosaTensorValuesGen.tvgDefault,
3590 None,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003599 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003600 "negate": {
3601 "op": Op.NEGATE,
3602 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 "build_fcn": (
3604 build_unary,
3605 TosaTensorGen.tgBasic,
3606 TosaTensorValuesGen.tvgNegate,
3607 None,
3608 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003609 "qgen": TosaQuantGen.qgUnary,
3610 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003611 "error_if_validators": (
3612 TosaErrorValidator.evInputZeroPointNotZero,
3613 TosaErrorValidator.evOutputZeroPointNotZero,
3614 TosaErrorValidator.evWrongInputType,
3615 TosaErrorValidator.evWrongOutputType,
3616 TosaErrorValidator.evWrongInputList,
3617 TosaErrorValidator.evWrongOutputList,
3618 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003620 "reciprocal": {
3621 "op": Op.RECIPROCAL,
3622 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003623 "build_fcn": (
3624 build_unary,
3625 TosaTensorGen.tgBasic,
3626 TosaTensorValuesGen.tvgDefault,
3627 None,
3628 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003630 "error_if_validators": (
3631 TosaErrorValidator.evWrongInputType,
3632 TosaErrorValidator.evWrongOutputType,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003637 "rsqrt": {
3638 "op": Op.RSQRT,
3639 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003640 "build_fcn": (
3641 build_unary,
3642 TosaTensorGen.tgBasic,
3643 TosaTensorValuesGen.tvgDefault,
3644 None,
3645 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 "error_if_validators": (
3648 TosaErrorValidator.evWrongInputType,
3649 TosaErrorValidator.evWrongOutputType,
3650 TosaErrorValidator.evWrongInputList,
3651 TosaErrorValidator.evWrongOutputList,
3652 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003654 # Elementwise Ternary operators
3655 "select": {
3656 "op": Op.SELECT,
3657 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 "build_fcn": (
3659 build_select,
3660 TosaTensorGen.tgBroadcastFuzz,
3661 TosaTensorValuesGen.tvgSelect,
3662 None,
3663 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003664 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003665 "error_if_validators": (
3666 TosaErrorValidator.evRankMismatch,
3667 TosaErrorValidator.evWrongInputType,
3668 TosaErrorValidator.evWrongOutputType,
3669 TosaErrorValidator.evWrongInputList,
3670 TosaErrorValidator.evWrongOutputList,
3671 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003672 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003674 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 # Comparison operators
3676 "equal": {
3677 "op": Op.EQUAL,
3678 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_comparison,
3681 TosaTensorGen.tgBroadcastFuzz,
3682 TosaTensorValuesGen.tvgEqual,
3683 None,
3684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003686 "error_if_validators": (
3687 TosaErrorValidator.evRankMismatch,
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003693 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "greater_equal": {
3697 "op": Op.GREATER_EQUAL,
3698 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 "build_fcn": (
3700 build_comparison,
3701 TosaTensorGen.tgBroadcastFuzz,
3702 TosaTensorValuesGen.tvgDefault,
3703 None,
3704 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003706 "error_if_validators": (
3707 TosaErrorValidator.evRankMismatch,
3708 TosaErrorValidator.evWrongInputType,
3709 TosaErrorValidator.evWrongOutputType,
3710 TosaErrorValidator.evWrongInputList,
3711 TosaErrorValidator.evWrongOutputList,
3712 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003713 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003715 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "greater": {
3717 "op": Op.GREATER,
3718 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 "build_fcn": (
3720 build_comparison,
3721 TosaTensorGen.tgBroadcastFuzz,
3722 TosaTensorValuesGen.tvgDefault,
3723 None,
3724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003726 "error_if_validators": (
3727 TosaErrorValidator.evRankMismatch,
3728 TosaErrorValidator.evWrongInputType,
3729 TosaErrorValidator.evWrongOutputType,
3730 TosaErrorValidator.evWrongInputList,
3731 TosaErrorValidator.evWrongOutputList,
3732 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003733 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003736 # Reduction operators
3737 "reduce_all": {
3738 "op": Op.REDUCE_ALL,
3739 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003740 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 "build_fcn": (
3742 build_reduce,
3743 TosaTensorGen.tgBasic,
3744 TosaTensorValuesGen.tvgDefault,
3745 TosaArgGen.agAxis,
3746 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 "error_if_validators": (
3749 TosaErrorValidator.evAxisLargerRank,
3750 TosaErrorValidator.evAxisSmallerZero,
3751 TosaErrorValidator.evShapeOfAxisNotOne,
3752 TosaErrorValidator.evWrongInputType,
3753 TosaErrorValidator.evWrongOutputType,
3754 TosaErrorValidator.evWrongRank,
3755 TosaErrorValidator.evWrongInputList,
3756 TosaErrorValidator.evWrongOutputList,
3757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "reduce_any": {
3760 "op": Op.REDUCE_ANY,
3761 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003762 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003763 "build_fcn": (
3764 build_reduce,
3765 TosaTensorGen.tgBasic,
3766 TosaTensorValuesGen.tvgDefault,
3767 TosaArgGen.agAxis,
3768 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003770 "error_if_validators": (
3771 TosaErrorValidator.evAxisLargerRank,
3772 TosaErrorValidator.evAxisSmallerZero,
3773 TosaErrorValidator.evShapeOfAxisNotOne,
3774 TosaErrorValidator.evWrongInputType,
3775 TosaErrorValidator.evWrongOutputType,
3776 TosaErrorValidator.evWrongRank,
3777 TosaErrorValidator.evWrongInputList,
3778 TosaErrorValidator.evWrongOutputList,
3779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "reduce_max": {
3782 "op": Op.REDUCE_MAX,
3783 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003784 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003785 "build_fcn": (
3786 build_reduce,
3787 TosaTensorGen.tgBasic,
3788 TosaTensorValuesGen.tvgDefault,
3789 TosaArgGen.agAxis,
3790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evAxisLargerRank,
3794 TosaErrorValidator.evAxisSmallerZero,
3795 TosaErrorValidator.evShapeOfAxisNotOne,
3796 TosaErrorValidator.evWrongInputType,
3797 TosaErrorValidator.evWrongOutputType,
3798 TosaErrorValidator.evWrongRank,
3799 TosaErrorValidator.evWrongInputList,
3800 TosaErrorValidator.evWrongOutputList,
3801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003804 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003806 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 "build_fcn": (
3808 build_reduce,
3809 TosaTensorGen.tgBasic,
3810 TosaTensorValuesGen.tvgDefault,
3811 TosaArgGen.agAxis,
3812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003814 "error_if_validators": (
3815 TosaErrorValidator.evAxisLargerRank,
3816 TosaErrorValidator.evAxisSmallerZero,
3817 TosaErrorValidator.evShapeOfAxisNotOne,
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongRank,
3821 TosaErrorValidator.evWrongInputList,
3822 TosaErrorValidator.evWrongOutputList,
3823 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 "reduce_product": {
3826 "op": Op.REDUCE_PRODUCT,
3827 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003828 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 "build_fcn": (
3830 build_reduce,
3831 TosaTensorGen.tgBasic,
3832 TosaTensorValuesGen.tvgDefault,
3833 TosaArgGen.agAxis,
3834 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003835 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003836 "error_if_validators": (
3837 TosaErrorValidator.evAxisLargerRank,
3838 TosaErrorValidator.evAxisSmallerZero,
3839 TosaErrorValidator.evShapeOfAxisNotOne,
3840 TosaErrorValidator.evWrongInputType,
3841 TosaErrorValidator.evWrongOutputType,
3842 TosaErrorValidator.evWrongRank,
3843 TosaErrorValidator.evWrongInputList,
3844 TosaErrorValidator.evWrongOutputList,
3845 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003846 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003847 "reduce_sum": {
3848 "op": Op.REDUCE_SUM,
3849 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003850 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003851 "build_fcn": (
3852 build_reduce,
3853 TosaTensorGen.tgBasic,
3854 TosaTensorValuesGen.tvgReduceSum,
3855 TosaArgGen.agAxis,
3856 ),
James Ward24dbc422022-10-19 12:20:31 +01003857 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003858 "error_if_validators": (
3859 TosaErrorValidator.evAxisLargerRank,
3860 TosaErrorValidator.evAxisSmallerZero,
3861 TosaErrorValidator.evShapeOfAxisNotOne,
3862 TosaErrorValidator.evWrongInputType,
3863 TosaErrorValidator.evWrongOutputType,
3864 TosaErrorValidator.evWrongRank,
3865 TosaErrorValidator.evWrongInputList,
3866 TosaErrorValidator.evWrongOutputList,
3867 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003868 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003869 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003870 "concat": {
3871 "op": Op.CONCAT,
3872 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003873 "build_fcn": (
3874 build_concat,
3875 TosaTensorGen.tgConcat,
3876 TosaTensorValuesGen.tvgConcat,
3877 TosaArgGen.agAxis,
3878 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003879 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003880 "error_if_validators": (
3881 TosaErrorValidator.evAxisLargerRank,
3882 TosaErrorValidator.evAxisSmallerZero,
3883 TosaErrorValidator.evConcatInputRankMismatch,
3884 TosaErrorValidator.evConcatShapeSumMismatch,
3885 TosaErrorValidator.evConcatInputDimMismatch,
3886 TosaErrorValidator.evWrongInputType,
3887 TosaErrorValidator.evWrongOutputType,
3888 TosaErrorValidator.evWrongOutputList,
3889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003890 },
3891 "pad": {
3892 "op": Op.PAD,
3893 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 "build_fcn": (
3895 build_pad,
3896 TosaTensorGen.tgBasic,
3897 TosaTensorValuesGen.tvgDefault,
3898 TosaArgGen.agPad,
3899 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003900 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003901 "error_if_validators": (
3902 TosaErrorValidator.evWrongInputType,
3903 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003904 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003905 TosaErrorValidator.evWrongOutputType,
3906 TosaErrorValidator.evWrongInputList,
3907 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003908 TosaErrorValidator.evRankMismatch,
3909 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003911 },
Won Jeona21b2e82023-08-10 10:33:01 +00003912 "dim": {
3913 "op": Op.DIM,
3914 "operands": (1, 0),
3915 "build_fcn": (
3916 build_dim,
3917 TosaTensorGen.tgBasic,
3918 TosaTensorValuesGen.tvgDefault,
3919 TosaArgGen.agAxis,
3920 ),
3921 "types": TYPE_FIB,
3922 "error_if_validators": (
3923 TosaErrorValidator.evAxisLargerRank,
3924 TosaErrorValidator.evAxisSmallerZero,
3925 TosaErrorValidator.evWrongInputType,
3926 TosaErrorValidator.evWrongInputList,
3927 TosaErrorValidator.evWrongOutputList,
3928 TosaErrorValidator.evWrongRank,
3929 ),
3930 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003931 "reshape": {
3932 "op": Op.RESHAPE,
3933 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 "build_fcn": (
3935 build_reshape,
3936 TosaTensorGen.tgBasic,
3937 TosaTensorValuesGen.tvgDefault,
3938 TosaArgGen.agReshape,
3939 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003940 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003941 "error_if_validators": (
3942 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3943 TosaErrorValidator.evWrongInputType,
3944 TosaErrorValidator.evWrongOutputType,
3945 TosaErrorValidator.evWrongInputList,
3946 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003947 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3948 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003950 },
3951 "reverse": {
3952 "op": Op.REVERSE,
3953 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003954 "build_fcn": (
3955 build_reverse,
3956 TosaTensorGen.tgBasic,
3957 TosaTensorValuesGen.tvgDefault,
3958 TosaArgGen.agAxis,
3959 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003960 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003961 "error_if_validators": (
3962 TosaErrorValidator.evAxisSmallerZero,
3963 TosaErrorValidator.evAxisLargerRank,
3964 TosaErrorValidator.evWrongInputType,
3965 TosaErrorValidator.evWrongOutputType,
3966 TosaErrorValidator.evWrongInputList,
3967 TosaErrorValidator.evWrongOutputList,
3968 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003969 },
3970 "slice": {
3971 "op": Op.SLICE,
3972 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003973 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003974 "build_fcn": (
3975 build_slice,
3976 TosaTensorGen.tgBasic,
3977 TosaTensorValuesGen.tvgDefault,
3978 TosaArgGen.agSlice,
3979 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003980 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 "error_if_validators": (
3982 TosaErrorValidator.evStartSmallerZero,
3983 TosaErrorValidator.evSizeSmallerEqualZero,
3984 TosaErrorValidator.evStartSizeOutsideBounds,
3985 TosaErrorValidator.evSizeOutputShapeMismatch,
3986 TosaErrorValidator.evInputSizeStartLengthMismatch,
3987 TosaErrorValidator.evWrongRank,
3988 TosaErrorValidator.evWrongInputType,
3989 TosaErrorValidator.evWrongOutputType,
3990 TosaErrorValidator.evWrongInputList,
3991 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003992 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003993 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003994 },
3995 "tile": {
3996 "op": Op.TILE,
3997 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003998 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 "build_fcn": (
4000 build_tile,
4001 TosaTensorGen.tgBasic,
4002 TosaTensorValuesGen.tvgDefault,
4003 TosaArgGen.agTile,
4004 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004005 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004006 "error_if_validators": (
4007 TosaErrorValidator.evWrongInputType,
4008 TosaErrorValidator.evWrongOutputType,
4009 TosaErrorValidator.evWrongInputList,
4010 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004011 TosaErrorValidator.evRankMismatch,
4012 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004013 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004014 },
4015 "transpose": {
4016 "op": Op.TRANSPOSE,
4017 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004018 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004019 "build_fcn": (
4020 build_transpose,
4021 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004022 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004023 TosaArgGen.agTranspose,
4024 ),
4025 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004026 "error_if_validators": (
4027 TosaErrorValidator.evIndexOutsideBounds,
4028 TosaErrorValidator.evIndexUsedTwice,
4029 TosaErrorValidator.evWrongInputType,
4030 TosaErrorValidator.evWrongOutputType,
4031 TosaErrorValidator.evWrongInputList,
4032 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004033 TosaErrorValidator.evWrongRank,
4034 TosaErrorValidator.evRankMismatch,
4035 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004036 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004037 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004038 # Data nodes
4039 "const": {
4040 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004041 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004042 "build_fcn": (
4043 build_const,
4044 TosaTensorGen.tgBasic,
4045 TosaTensorValuesGen.tvgDefault,
4046 None,
4047 ),
Luke Hutton65872422023-02-20 10:33:04 +00004048 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004049 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004050 "identity": {
4051 "op": Op.IDENTITY,
4052 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004053 "build_fcn": (
4054 build_unary,
4055 TosaTensorGen.tgBasic,
4056 TosaTensorValuesGen.tvgDefault,
4057 None,
4058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "types": TYPE_FIB,
4060 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004061 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004062 "gather": {
4063 "op": Op.GATHER,
4064 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4065 "operands": (1, 0),
4066 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 "build_fcn": (
4068 build_gather,
4069 TosaTensorGen.tgBasic,
4070 TosaTensorValuesGen.tvgDefault,
4071 None,
4072 ),
James Ward24dbc422022-10-19 12:20:31 +01004073 "types": (
4074 DType.INT8,
4075 DType.INT16,
4076 DType.INT32,
4077 DType.FP16,
4078 DType.BF16,
4079 DType.FP32,
4080 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004081 "error_if_validators": (
4082 TosaErrorValidator.evWrongInputType,
4083 TosaErrorValidator.evWrongOutputType,
4084 TosaErrorValidator.evWrongInputList,
4085 TosaErrorValidator.evWrongOutputList,
4086 TosaErrorValidator.evWrongRank,
4087 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004088 },
4089 "scatter": {
4090 "op": Op.SCATTER,
4091 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004092 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004093 "operands": (2, 0),
4094 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004095 "build_fcn": (
4096 build_scatter,
4097 TosaTensorGen.tgScatter,
4098 TosaTensorValuesGen.tvgDefault,
4099 None,
4100 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004101 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004102 "error_if_validators": (
4103 TosaErrorValidator.evWrongInputType,
4104 TosaErrorValidator.evWrongOutputType,
4105 TosaErrorValidator.evWrongInputList,
4106 TosaErrorValidator.evWrongOutputList,
4107 TosaErrorValidator.evWrongRank,
4108 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004109 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004110 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004111 "resize": {
4112 "op": Op.RESIZE,
4113 "operands": (1, 0),
4114 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004115 "build_fcn": (
4116 build_resize,
4117 TosaTensorGen.tgNHWC,
4118 TosaTensorValuesGen.tvgDefault,
4119 TosaArgGen.agResize,
4120 ),
James Ward24dbc422022-10-19 12:20:31 +01004121 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004122 "invalid_test_validators": (
4123 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004124 ),
4125 "error_if_validators": (
4126 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004127 TosaErrorValidator.evScaleSmallerEqualZero,
4128 TosaErrorValidator.evScaleNLargerMax,
4129 TosaErrorValidator.evScaleDLargerMax,
4130 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004131 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004132 TosaErrorValidator.evBorderSmallerMin,
4133 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004134 TosaErrorValidator.evWrongInputType,
4135 TosaErrorValidator.evWrongOutputType,
4136 TosaErrorValidator.evWrongRank,
4137 TosaErrorValidator.evWrongInputList,
4138 TosaErrorValidator.evWrongOutputList,
4139 TosaErrorValidator.evBatchMismatch,
4140 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004141 TosaErrorValidator.evResizeOutputShapeMismatch,
4142 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004143 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004144 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004145 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004146 "cast": {
4147 "op": Op.CAST,
4148 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 "build_fcn": (
4150 build_cast,
4151 TosaTensorGen.tgBasic,
4152 TosaTensorValuesGen.tvgDefault,
4153 TosaArgGen.agCast,
4154 ),
James Ward8b390432022-08-12 20:48:56 +01004155 "types": (
4156 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004157 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004158 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004159 DType.INT8,
4160 DType.INT16,
4161 DType.INT32,
4162 DType.BOOL,
4163 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004164 "error_if_validators": (
4165 TosaErrorValidator.evWrongInputType,
4166 TosaErrorValidator.evWrongOutputType,
4167 TosaErrorValidator.evWrongInputList,
4168 TosaErrorValidator.evWrongOutputList,
4169 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004170 },
4171 "rescale": {
4172 "op": Op.RESCALE,
4173 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004174 "build_fcn": (
4175 build_rescale,
4176 TosaTensorGen.tgBasic,
4177 TosaTensorValuesGen.tvgDefault,
4178 TosaArgGen.agRescale,
4179 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004180 "types": [
4181 DType.UINT8,
4182 DType.INT8,
4183 DType.INT16,
4184 DType.INT32,
4185 DType.INT48,
4186 DType.UINT16,
4187 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 "error_if_validators": (
4189 TosaErrorValidator.evInputZeroPointNotZero,
4190 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004191 TosaErrorValidator.evU16InputZeroPointNotValid,
4192 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004193 TosaErrorValidator.evScaleTrue,
4194 TosaErrorValidator.evScaleNotTrue,
4195 TosaErrorValidator.evWrongInputType,
4196 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004197 TosaErrorValidator.evWrongInputList,
4198 TosaErrorValidator.evWrongOutputList,
4199 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004200 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004201 # Custom
4202 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004203 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004204 # Two variants of cond_if, one that generates one of two constant tensors (no
4205 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4206 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004207 "cond_if_const": {
4208 "op": Op.COND_IF,
4209 "operands": (0, 2),
4210 "build_fcn": (
4211 build_cond_if_const,
4212 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004213 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004214 TosaArgGen.agCondIf,
4215 ),
4216 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004217 "error_if_validators": (
4218 TosaErrorValidator.evOutputListThenGraphMismatch,
4219 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004220 TosaErrorValidator.evCondIfCondNotMatchingBool,
4221 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004222 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004223 },
4224 "cond_if_binary": {
4225 "op": Op.COND_IF,
4226 "operands": (2, 0),
4227 "build_fcn": (
4228 build_cond_if_binary,
4229 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004231 TosaArgGen.agCondIf,
4232 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004233 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004234 "error_if_validators": (
4235 TosaErrorValidator.evInputListThenGraphMismatch,
4236 TosaErrorValidator.evInputListElseGraphMismatch,
4237 TosaErrorValidator.evOutputListThenGraphMismatch,
4238 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004239 TosaErrorValidator.evCondIfCondNotMatchingBool,
4240 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004241 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004242 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004243 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004244 "while_loop": {
4245 "op": Op.WHILE_LOOP,
4246 "operands": (0, 1),
4247 "build_fcn": (
4248 build_while_loop,
4249 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004250 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 TosaArgGen.agWhileLoop,
4252 ),
4253 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004254 "error_if_validators": (
4255 TosaErrorValidator.evInputListOutputListMismatch,
4256 TosaErrorValidator.evInputListCondGraphMismatch,
4257 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4258 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4259 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004260 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004261 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004262 },
Luke Hutton57287132023-02-06 14:54:18 +00004263 "fft2d": {
4264 "op": Op.FFT2D,
4265 "operands": (2, 0),
4266 "rank": (3, 3),
4267 "build_fcn": (
4268 build_fft2d,
4269 TosaTensorGen.tgFFT2d,
4270 TosaTensorValuesGen.tvgDefault,
4271 TosaArgGen.agFFT2d,
4272 ),
4273 "types": [DType.FP32],
4274 "error_if_validators": (
4275 TosaErrorValidator.evWrongInputType,
4276 TosaErrorValidator.evWrongOutputType,
4277 TosaErrorValidator.evWrongInputList,
4278 TosaErrorValidator.evWrongOutputList,
4279 TosaErrorValidator.evWrongRank,
4280 TosaErrorValidator.evBatchMismatch,
4281 TosaErrorValidator.evKernelNotPowerOfTwo,
4282 TosaErrorValidator.evFFTInputShapeMismatch,
4283 TosaErrorValidator.evFFTOutputShapeMismatch,
4284 ),
4285 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004286 "rfft2d": {
4287 "op": Op.RFFT2D,
4288 "operands": (1, 0),
4289 "rank": (3, 3),
4290 "build_fcn": (
4291 build_rfft2d,
4292 TosaTensorGen.tgRFFT2d,
4293 TosaTensorValuesGen.tvgDefault,
4294 TosaArgGen.agNone,
4295 ),
4296 "types": [DType.FP32],
4297 "error_if_validators": (
4298 TosaErrorValidator.evWrongInputType,
4299 TosaErrorValidator.evWrongOutputType,
4300 TosaErrorValidator.evWrongInputList,
4301 TosaErrorValidator.evWrongOutputList,
4302 TosaErrorValidator.evWrongRank,
4303 TosaErrorValidator.evBatchMismatch,
4304 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004305 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004306 ),
4307 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004308 }
4309
Kevin Cheng550ccc52021-03-03 11:21:43 -08004310
Eric Kunzee5e26762020-10-13 16:11:07 -07004311class OutputShaper:
4312 # Methods in this class compute the expected output shape and datatype
4313 # for common classes of operations
def __init__(self):
    # Nothing to initialise: no instance attributes are created here.
    pass
4316
4317 # These methods return arguments that can be used for
4318 # creating a new output tensor
4319 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting elementwise binary operator.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice()
            (presumably a numpy Generator - confirm against the caller).
        a: first input tensor, read through its .shape and .dtype.
        b: second input tensor, read through its .shape and .dtype.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # A size-1 dimension of `a` takes its size from `b`, but only for valid
    # tests; when an error case is generated `a`'s dimensions are kept as is.
    shape = [
        b.shape[i] if (dim == 1 and error_name is None) else dim
        for i, dim in enumerate(a.shape)
    ]

    # The index is still drawn on every call with rank >= 1, so the sequence
    # of rng draws is the same as before. A rank 0 tensor has no dimension to
    # pick and integers(0, 0) would raise (ValueError on a numpy Generator),
    # so the draw is skipped there.
    rank = len(a.shape)
    fuzz_idx = rng.integers(0, rank) if rank > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any dtype other than the input dtype is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - {a.dtype})
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004352
4353 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary operator without broadcasting.

    Both inputs must have the same rank, dtype and per-dimension sizes
    (checked with asserts). The output copies `a`'s dimensions and dtype.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        a: first input tensor, read through its .shape and .dtype.
        b: second input tensor, read through its .shape and .dtype.

    Returns:
        Whatever ser.addOutput() returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_dims = []
    for idx, dim_a in enumerate(a.shape):
        # No broadcasting allowed: every dimension has to match exactly
        assert dim_a == b.shape[idx]
        out_dims.append(dim_a)

    return ser.addOutput(out_dims, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004364
4365 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary operator.

    The output has the same shape as `a`. Its dtype is `a.dtype`, unless a
    WrongOutputType error is being generated, in which case a different dtype
    is picked at random.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        a: input tensor, read through its .shape and .dtype.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Every candidate except the input dtype counts as a wrong output dtype
    wrong_dtypes = list(set(candidates) - {a.dtype})
    bad_dtype = rng.choice(wrong_dtypes)
    return ser.addOutput(a.shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004383
4384 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select operator.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice()
            (presumably a numpy Generator - confirm against the caller).
        cond: condition tensor, read through its .shape.
        a: first value tensor, read through its .shape and .dtype.
        b: second value tensor, read through its .shape and .dtype.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    # For valid tests a size-1 dimension of `cond` is widened to the largest
    # of the three inputs; otherwise `cond`'s own dimension is used.
    shape = [
        max(dim, a.shape[i], b.shape[i]) if (dim == 1 and error_name is None) else dim
        for i, dim in enumerate(cond.shape)
    ]

    # The index is still drawn on every call with rank >= 1, so the sequence
    # of rng draws is the same as before. A rank 0 tensor has no dimension to
    # pick and integers(0, 0) would raise (ValueError on a numpy Generator),
    # so the draw is skipped there.
    rank = len(a.shape)
    fuzz_idx = rng.integers(0, rank) if rank > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any dtype other than the value tensors' dtype is a wrong output dtype
        wrong_dtypes = list(set(all_dtypes) - {a.dtype})
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004417
4418 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting comparison operator.

    The output dtype is DType.BOOL unless a WrongOutputType error is being
    generated, in which case a non-boolean dtype is picked at random.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice()
            (presumably a numpy Generator - confirm against the caller).
        a: first input tensor, read through its .shape and .dtype.
        b: second input tensor, read through its .shape and .dtype.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast - the length check guards the lookup in `b` when the ranks
    # differ (RankMismatch error case)
    shape = [
        b.shape[i] if (dim == 1 and len(b.shape) > i) else dim
        for i, dim in enumerate(a.shape)
    ]

    # The index is still drawn on every call with rank >= 1, so the sequence
    # of rng draws is the same as before. A rank 0 tensor has no dimension to
    # pick and integers(0, 0) would raise (ValueError on a numpy Generator),
    # so the draw is skipped there.
    rank = len(a.shape)
    fuzz_idx = rng.integers(0, rank) if rank > 0 else None
    if error_name == ErrorIf.DimensionMismatch and fuzz_idx is not None:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # The correct output is BOOL, so every dtype listed here is wrong
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004451
4452 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along `axis`.

    For valid tests (and most error cases) the reduced axis has size 1 in the
    output. The axis is left untouched for the AxisSmallerZero and
    AxisLargerRank cases. For ShapeOfAxisNotOne it is only changed, to a
    random size from rng.integers(2, 10), when the input already has size 1
    there.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice().
        a: input tensor, read through its .shape and .dtype.
        axis: index of the dimension being reduced.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The output must NOT have size 1 on the axis for this error case
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {a.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004480
4481 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for an argmax along `axis`.

    The output shape is the input shape with `axis` removed (unless an axis
    error case is generated) and the output dtype is DType.INT32. The error
    cases handled here distort the rank, the dimension sizes or the dtype.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice().
        a: input tensor, read through its .shape.
        axis: index of the dimension removed from the output.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    out_shape = a.shape.copy()

    axis_errors = (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank)
    if error_name not in axis_errors:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop a dimension (when more than one is left) or add one
        drop_dim = rng.choice([True, False])
        if drop_dim and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every dimension by a random amount
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {DType.INT32})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004514
4515 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Layouts as used by the indexing below:
        IFM: NHWC, Filter: OHWI, OFM: NHWC

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        ifm: input tensor, read through its .shape and .dtype.
        filter: weight tensor, read through its .shape.
        accum_dtype: output dtype used unless an error case overrides it.
        strides: per-axis strides, indexed [0] height and [1] width.
        padding: four values, indexed [0:2] for height and [2:4] for width.
        dilations: per-axis dilations, indexed [0] height and [1] width.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """

    def spatial_size(axis):
        # axis 0 is height, axis 1 is width; both IFM and filter keep that
        # dimension at index axis + 1
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(0), spatial_size(1)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height only, 2 -> change width only, 3 -> change both
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        for axis in range(2):
            if change in (axis + 1, 3):
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004566
4567 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 3D convolution.

    Layouts as used by the indexing below:
        IFM: NDHWC, Filter: ODHWI, OFM: NDHWC

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        ifm: input tensor, read through its .shape and .dtype.
        filter: weight tensor, read through its .shape.
        accum_dtype: output dtype used unless an error case overrides it.
        strides: per-axis strides, indexed [0] depth, [1] height, [2] width.
        padding: six values, two per axis in depth, height, width order.
        dilations: per-axis dilations, indexed like strides.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """

    def spatial_size(axis):
        # axis 0 is depth, 1 is height, 2 is width; both IFM and filter keep
        # that dimension at index axis + 1
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(axis) for axis in range(3)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # 1 -> depth only, 2 -> height only, 3 -> width only, 4 -> all three
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        for axis in range(3):
            if change in (axis + 1, 4):
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4628
4629 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a depthwise 2D convolution.

    Layouts as used by the indexing below:
        IFM: NHWC, Filter: HWCM, OFM: NHW C*M

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        ifm: input tensor, read through its .shape and .dtype.
        filter: weight tensor, read through its .shape.
        accum_dtype: output dtype used unless an error case overrides it.
        strides: per-axis strides, indexed [0] height and [1] width.
        padding: four values, indexed [0:2] for height and [2:4] for width.
        dilations: per-axis dilations, indexed [0] height and [1] width.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """

    def spatial_size(axis):
        # axis 0 is height, axis 1 is width; the IFM keeps that dimension at
        # index axis + 1 while the HWCM filter keeps it at index axis
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(0), spatial_size(1)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height only, 2 -> change width only, 3 -> change both
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        for axis in range(2):
            if change in (axis + 1, 3):
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    # Output channels are C * M
    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], *spatial, out_channels]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004679
4680 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        ifm: input tensor (indexed as NHWC), read through .shape and .dtype.
        kernel: kernel sizes, indexed [0] height and [1] width.
        stride: strides, indexed [0] height and [1] width.
        pad: four padding values, [0:2] for height and [2:4] for width.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    # input: NHWC
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride or pad is used set dimensions to 1, test is
        # invalid anyway.
        out_h = 1
        out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change height only, 2 -> change width only, 3 -> change both
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in (1, 3):
            out_h = out_h + (rng.choice(choices) * stride[0])
        if change in (2, 3):
            out_w = out_w + (rng.choice(choices) * stride[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {ifm.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004717
4718 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for a fully connected operator.

    Layouts as used by the indexing below:
        input: N, IC   filter: OC, IC   output: N, OC

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: unused here.
        input: input tensor, read through its .shape.
        filter: weight tensor, read through its .shape.
        accum_dtype: dtype given to the output tensor.
        error_name: unused here.

    Returns:
        Whatever ser.addOutput() returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]

    # accum_dtype is validated in arg_gen (also invalidated for ErrorIf), so it
    # is used directly as the output dtype
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004730
4731 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for a batched matrix multiply.

    Layouts as used by the indexing below:
        a: N, H, C   b: N, C, W   out: N, H, W

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        a: first input tensor, read through its .shape and .dtype.
        b: second input tensor, read through its .shape.
        accum_dtype: output dtype for tests that do not override it.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    out_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        float_types = (DType.FP32, DType.FP16, DType.BF16)
        # NOTE: only the input dtypes below are handled; there is no fallback
        # list of incorrect types for any other input dtype.
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
            ) + float_types
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
            ) + float_types
        elif a.dtype in float_types:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        result_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        result_dtype = DType.INT32
    else:
        # Validated in arg_gen
        result_dtype = accum_dtype

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004778
4779 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for a concatenation along `axis`.

    The output starts as a copy of the first input's shape. The other inputs'
    sizes along `axis` are added to it, unless the error case makes that
    impossible.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice().
        axis: index of the dimension being concatenated.
        *a: input tensors, each read through .shape (first one also .dtype).
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    first = a[0]
    others = a[1:]

    # calculate the output shape, if possible, otherwise just use the first
    # input shape: tensors of different ranks or an invalid axis cannot be
    # summed along the axis
    out_shape = first.shape.copy()
    unsummable = (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if error_name not in unsummable:
        for tensor in others:
            out_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        wrong_dtypes = list(candidates - {first.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004814
4815 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for a pad operator.

    Each output dimension is the input dimension plus the two padding values
    given for it.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; rng.choice() is used directly and rng is also
            handed to gtu.get_rank_mismatch_shape().
        a: input tensor, read through its .shape and .dtype.
        padding: per-dimension pairs, indexed as padding[i][0] and
            padding[i][1].
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    padded = a.shape.copy()
    for idx, dim in enumerate(a.shape):
        padded[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2
        bad_dim = rng.choice(range(len(padded)))
        padded[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        padded = gtu.get_rank_mismatch_shape(rng, padded)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {a.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = a.dtype

    return ser.addOutput(padded, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004845
4846 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a dim operator.

    The output always has an empty shape list. Its dtype is DType.SHAPE unless
    a WrongOutputType error is being generated.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; only rng.choice() is used.
        a: unused here.
        axis: unused here.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([], DType.SHAPE)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # The set round trip is kept so the order of the options given to
    # rng.choice() stays as it was.
    wrong_dtypes = list(set(candidates))
    bad_dtype = rng.choice(wrong_dtypes)
    return ser.addOutput([], bad_dtype)
4866
4867 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for a reshape operator.

    The output shape is a copy of `shape` and the output dtype is `a.dtype`,
    unless the error case being generated distorts one of them.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator providing integers() and choice().
        a: input tensor, read through its .dtype.
        shape: requested new shape; copied, never modified in place.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension so the element count no longer matches
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {a.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = a.dtype

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004891
4892 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for a slice operator.

    The output shape is a copy of `size` and the output dtype is
    `input.dtype`, unless the error case being generated distorts one of them.

    Args:
        ser: serializer object; only ser.addOutput(shape, dtype) is used.
        rng: random generator; rng.choice() is used directly and rng is also
            handed to gtu.get_rank_mismatch_shape().
        input: input tensor, read through its .shape and .dtype.
        start: unused here.
        size: requested slice size per dimension; copied, not modified.
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput() returns.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidates) - {input.dtype})
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dimensions are only ever grown; larger ones may also shrink
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004925
4926 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for a tile operation to the serializer.

    Each output dimension is the matching dimension of ``a`` multiplied by
    the matching entry of ``multiples``.  ``error_name`` may request a
    rank-mismatched shape or a dtype different from ``a.dtype``.

    Returns the value returned by ``ser.addOutput``.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for axis, extent in enumerate(a.shape):
        tiled_shape[axis] = extent * multiples[axis]

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = gtu.get_rank_mismatch_shape(rng, tiled_shape)

    result_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Choose any candidate except the dtype the input already has
        mismatched = list(set(candidates) - {a.dtype})
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(tiled_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004954
4955 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for a transpose operation to the serializer.

    The output shape is ``a.shape`` reordered by ``perms``, except for the
    IndexOutsideBounds and IndexUsedTwice error cases, where the input
    shape is kept unpermuted.  Other ``error_name`` values perturb the
    shape or the dtype.

    Returns the value returned by ``ser.addOutput``.
    """
    permuted_shape = a.shape.copy()

    assert len(perms) == len(permuted_shape)

    keep_input_order = error_name in (
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    )
    if not keep_input_order:
        for axis in range(len(permuted_shape)):
            permuted_shape[axis] = a.shape[perms[axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension by an offset drawn with rng.integers(1, 10)
        for axis in range(len(permuted_shape)):
            permuted_shape[axis] = permuted_shape[axis] + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        permuted_shape = gtu.get_rank_mismatch_shape(rng, permuted_shape)

    result_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Choose any candidate except the dtype the input already has
        mismatched = list(set(candidates) - {a.dtype})
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(permuted_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004987
4988 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for a gather operation to the serializer.

    The output shape is built from dimension 0 and dimension 2 of
    ``values`` with dimension 1 of ``indices`` in between.  The output
    dtype is ``values.dtype`` unless WrongOutputType is requested.

    Returns the value returned by ``ser.addOutput``.
    """
    # The rank of `values` is only enforced when WrongRank is not requested
    assert error_name == ErrorIf.WrongRank or len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    gathered_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    result_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Choose any candidate except the dtype of `values`
        mismatched = list(set(candidates) - {values.dtype})
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(gathered_shape, result_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005013
5014 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for a scatter operation to the serializer.

    The output reuses ``values_in.shape`` (the same object, not a copy) and
    ``values_in.dtype``, unless WrongOutputType asks for a different dtype.

    Returns the value returned by ``ser.addOutput``.
    """
    # The rank of `values_in` is only enforced when WrongRank is not requested
    assert error_name == ErrorIf.WrongRank or len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    scattered_shape = values_in.shape

    result_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Choose any candidate except the dtype of `values_in`
        mismatched = list(set(candidates) - {values_in.dtype})
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(scattered_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005042
5043 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for a table operation to the serializer.

    The output has the same shape as the input.  Its dtype is INT32 when
    the input is INT16 and INT8 otherwise; WrongOutputType replaces that
    with a different dtype picked by ``rng``.

    Returns the value returned by ``ser.addOutput``.
    """
    # The input dtype is only enforced when WrongInputType is not requested
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Everything except the dtype that would have been correct
        mismatched = [dtype for dtype in candidates if dtype != output_dtype]
        output_dtype = rng.choice(mismatched)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005062
5063 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for a resize operation to the serializer.

    The output height and width are computed from dimensions 1 and 2 of
    ``input.shape`` together with ``scale`` (y numerator, y denominator,
    x numerator, x denominator), ``offset`` and ``border``.  The result is
    then adjusted according to ``error_name``.  ``mode`` and
    ``input_dtype`` are accepted but are not read here.

    Returns the value returned by ``serializer.addOutput``.
    """

    def _out_extent(in_extent, numerator, denominator, start, pad):
        # Output size along one axis
        return ((in_extent - 1) * numerator - start + pad) // denominator + 1

    def _shift_by(extent, step):
        # Move by one step; go down instead of up when going up would
        # reach gtu.MAX_RESIZE_DIMENSION
        if extent + step >= gtu.MAX_RESIZE_DIMENSION:
            extent -= step
            assert extent > 0  # Should have been caught in agResize
            return extent
        return extent + step

    # Calculate OH, OW
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp all four to at least 1; the denominators are divisors below
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(value, 1)
            for value in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = _out_extent(input.shape[1], scale_y_n, scale_y_d, offset[0], border[0])
    ow = _out_extent(input.shape[2], scale_x_n, scale_x_d, offset[1], border[1])

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # 1 changes the height, 2 changes the width, 3 changes both
        which = rng.choice([1, 2, 3])
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if which in (1, 3):
            oh = _shift_by(oh, scale_y_d)
        if which in (2, 3):
            ow = _shift_by(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # Last entry is taken from dimension 0 here; dimension 3 is not read
        last = input.shape[0]
    else:
        last = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            last = last + rng.integers(1, 10)
    output_dims = [batch, oh, ow, last]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005141
5142 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor with the shape of ``val`` and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but are not read here.

    Returns the value returned by ``ser.addOutput``.
    """
    source_shape = val.shape
    return ser.addOutput(source_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005145
5146 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for a transpose conv2d to the serializer.

    The output uses the caller-supplied ``output_shape`` and, by default,
    ``accum_dtype``.  For ConvOutputShapeMismatch the entries at index 1
    and/or 2 of ``output_shape`` are increased in place, so the caller's
    list is modified.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1 changes index 1, 2 changes index 2, 3 changes both
        which = rng.choice(amounts)
        if which in (1, 3):
            output_shape[1] += rng.choice(amounts)
        if which in (2, 3):
            output_shape[2] += rng.choice(amounts)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For an FP16 input both FP16 and FP32 are excluded from the choice
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005171
5172 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an fft2d operation to the serializer.

    Both outputs share one shape and one dtype, taken from ``ifm1`` unless
    ``error_name`` asks for a wrong dtype, a changed batch dimension or a
    changed dimension 1 or 2.

    Returns a list holding the two values returned by
    ``serializer.addOutput``, in call order.
    """
    assert ifm1.dtype == ifm2.dtype

    # Input shapes only have to agree when no shape mismatch is requested
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    result_shape = in_shape.copy()
    result_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] = result_shape[0] + rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        result_shape[axis] = result_shape[axis] + rng.integers(1, 10)

    first = serializer.addOutput(result_shape, result_dtype)
    second = serializer.addOutput(result_shape, result_dtype)
    return [first, second]
5202
5203 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an rfft2d operation to the serializer.

    The output shape keeps every input dimension except the last, which
    becomes ``last // 2 + 1``.  Both outputs share that shape and
    ``value.dtype`` unless ``error_name`` asks for a wrong dtype, a changed
    batch dimension or a changed dimension 1 or 2.

    Returns a list holding the two values returned by
    ``serializer.addOutput``, in call order.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    result_shape = list(in_shape[:-1])
    result_shape.append(in_shape[-1] // 2 + 1)

    result_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        result_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        result_shape[0] = result_shape[0] + rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        result_shape[axis] = result_shape[axis] + rng.integers(1, 10)

    first = serializer.addOutput(result_shape, result_dtype)
    second = serializer.addOutput(result_shape, result_dtype)
    return [first, second]