blob: d15f785b7a1e188aa0154c6b5d9ae18154a8bd39 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010012from generator.tosa_arg_gen import TosaArgGen
13from generator.tosa_arg_gen import TosaQuantGen
14from generator.tosa_arg_gen import TosaTensorGen
15from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000016from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010017from generator.tosa_error_if import TosaErrorIfArgGen
18from generator.tosa_error_if import TosaErrorValidator
19from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010020from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
# Header prepended to every auto-generated C++ meta data source file.
# The copyright year is taken from the date the generator module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
28
Matthew Haddonb724efc2021-08-25 16:40:29 +010029
Eric Kunzee5e26762020-10-13 16:11:07 -070030class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010031 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000032 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010034 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010035 TOSA_8K_LEVEL_MAX_KERNEL = 8192
36 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010037
Jeremy Johnson1271c442023-09-05 11:39:26 +010038 # Main compliance dot product statistical test range
39 TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
40 TOSA_MI_DOT_PRODUCT_MIN = 1000
41
Eric Kunzee5e26762020-10-13 16:11:07 -070042 def __init__(self, args):
43 self.args = args
44 self.basePath = args.output_dir
45 self.random_seed = args.random_seed
46 self.ser = None
47 self.rng = np.random.default_rng(self.random_seed)
48 self.createDynamicOpLists()
49 self.initOpListDefaults()
50 self.quantGen = TosaQuantGen()
51 # Force makeShape to do a specific starting shape
52 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010053 # Work out floating point range
54 self.random_fp_low = min(args.tensor_fp_value_range)
55 self.random_fp_high = max(args.tensor_fp_value_range)
Jeremy Johnson1271c442023-09-05 11:39:26 +010056 # JSON schema validation
57 self.descSchemaValidator = TestDescSchemaValidator()
Eric Kunzee5e26762020-10-13 16:11:07 -070058
59 def createSerializer(self, opName, testPath):
60 self.testPath = os.path.join(opName, testPath)
61
62 fullPath = os.path.join(self.basePath, self.testPath)
63 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010064 # Embed const data in the flatbuffer
65 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010066 if self.args.lazy_data_gen:
67 # Lazy data generation - so make constants files
68 constMode = ts.ConstMode.INPUTS
69 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010070 constMode = ts.ConstMode.EMBED_DUMP
71 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070072
73 def getSerializer(self):
74 return self.ser
75
Jeremy Johnson1271c442023-09-05 11:39:26 +010076 def serialize(self, testName, metaData=None):
77 path = Path(self.basePath) / self.testPath
78
79 # Write out TOSA flatbuffer binary
80 path_fb = path / f"{testName}.tosa"
81 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070082 fd.write(self.ser.serialize())
83
Jeremy Johnson1271c442023-09-05 11:39:26 +010084 # Get JSON descriptor from serializer
85 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
86
87 if metaData:
88 # Add extra meta data to desc.json
89 desc["meta"] = metaData
90
91 # Validate desc.json before we output it
92 self.descSchemaValidator.validate_config(desc)
93
94 if metaData:
95 if self.args.lazy_data_gen and "data_gen" in metaData:
96 # Output datagen meta data as CPP data
97 path_md = path / f"{testName}_meta_data_gen.cpp"
98 with path_md.open("w") as fd:
99 fd.write(TOSA_AUTOGENERATED_HEADER)
100 fd.write("// Test meta data for data generation setup\n\n")
101 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
102 json.dump(metaData["data_gen"], fd)
103 fd.write(')";\n\n')
104 if "compliance" in metaData:
105 # Output datagen meta data as CPP data
106 path_md = path / f"{testName}_meta_compliance.cpp"
107 with path_md.open("w") as fd:
108 fd.write(TOSA_AUTOGENERATED_HEADER)
109 fd.write("// Test meta data for compliance validation\n\n")
110 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
111 json.dump(metaData["compliance"], fd)
112 fd.write(')";\n\n')
113
114 # Write desc.json
115 path_desc = path / "desc.json"
116 with path_desc.open("w") as fd:
117 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700118
Matthew Haddon74567092021-07-16 15:38:20 +0100119 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000120 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100121 seed = self.random_seed + 1
122 self.rng = np.random.default_rng(seed)
123
Jeremy Johnson1271c442023-09-05 11:39:26 +0100124 def getDTypeRange(self, dtype, high_inclusive=False):
125 # Returns dtype value range boundaries (low, high)
126 # The high boundary is excluded in the range
127 # unless high_inclusive is True
128
129 if dtype in (DType.FP32, DType.FP16, DType.BF16):
130 return (self.random_fp_low, self.random_fp_high)
131 elif dtype == DType.BOOL:
132 rng = (0, 2)
133 elif dtype == DType.UINT8:
134 rng = (0, 256)
135 elif dtype == DType.UINT16:
136 rng = (0, 65536)
137 elif dtype == DType.INT4:
138 # TOSA specific INT4 weight range from -7 to 7
139 rng = (-7, 8)
140 elif dtype == DType.INT8:
141 rng = (-128, 128)
142 elif dtype == DType.INT16:
143 rng = (-32768, 32768)
144 elif dtype in (DType.INT32, DType.SHAPE):
145 # restricting too large value for SHAPE
146 rng = (-(1 << 31), (1 << 31))
147 elif dtype == DType.INT48:
148 rng = (-(1 << 47), (1 << 47))
149 else:
150 raise Exception("Unknown dtype: {}".format(dtype))
151
152 if not high_inclusive:
153 # Exclusive high: low <= range < high
154 return rng
155 else:
156 # Inclusive range: low <= range <= high
157 return (rng[0], rng[1] - 1)
158
Eric Kunzee5e26762020-10-13 16:11:07 -0700159 def getRandTensor(self, shape, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100160 low, high = self.getDTypeRange(dtype)
161
Eric Kunzee5e26762020-10-13 16:11:07 -0700162 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700163 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 elif dtype == DType.INT48:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100165 return np.int64(self.rng.integers(low=low, high=high, size=shape))
166 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
167 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
168
169 if dtype == DType.FP16:
170 return np.float16(f_tensor)
171 else:
172 f32_tensor = np.float32(f_tensor)
173 if dtype == DType.BF16:
174 # Floor the last 16 bits of each f32 value
175 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
176 else:
177 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100179 # All other integer types
180 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700181
Kevin Cheng989cb052021-04-28 16:29:44 -0700182 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700183 placeholders = []
184
Kevin Cheng989cb052021-04-28 16:29:44 -0700185 assert len(shape_list) == len(dtype_list)
186
Jeremy Johnson1271c442023-09-05 11:39:26 +0100187 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700188 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100189 if not self.args.lazy_data_gen:
190 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700191 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700192
193 return placeholders
194
Kevin Cheng989cb052021-04-28 16:29:44 -0700195 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700196 consts = []
197
Kevin Cheng989cb052021-04-28 16:29:44 -0700198 assert len(shape_list) == len(dtype_list)
199
Jeremy Johnson1271c442023-09-05 11:39:26 +0100200 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700201 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100202 if not self.args.lazy_data_gen:
203 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700204 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700205
206 return consts
207
208 def makeShape(self, rank):
209 if self.targetted_shape:
210 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800211 return np.int32(
212 self.rng.integers(
213 low=self.args.tensor_shape_range[0],
214 high=self.args.tensor_shape_range[1],
215 size=rank,
216 )
217 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700218
219 def setTargetShape(self, shape):
220 self.targetted_shape = shape
221
222 def randInt(self, low=0, high=256):
223 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
224
225 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100226 low, high = self.getDTypeRange(dtype)
227
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100228 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100229 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100230 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100231 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100232 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100233 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
234 return gtu.vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700235 elif dtype == DType.BOOL:
236 return self.rng.choice([False, True])
Eric Kunzee5e26762020-10-13 16:11:07 -0700237 elif dtype == DType.INT48:
Eric Kunzee5e26762020-10-13 16:11:07 -0700238 # Special size
239 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700240
241 return np.int32(self.rng.integers(low, high, size=1))[0]
242
243 def shapeStr(self, shape):
244
245 sStr = []
246 # Convert to strings
247 for i in shape:
248 sStr.append(str(i))
249
Kevin Cheng550ccc52021-03-03 11:21:43 -0800250 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700251
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100252 def typeStr(self, dtype):
253 if isinstance(dtype, list) or isinstance(dtype, tuple):
254 assert len(dtype) >= 2
255 strs = [self.typeStr(t) for t in dtype]
256 # Limit types to the first 2 as the 3rd is the accumulator
257 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700258 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100259 if dtype in gtu.DTYPE_ATTRIBUTES:
260 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700261 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100262 raise Exception(
263 "Unknown dtype, cannot convert to string: {}".format(dtype)
264 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700265
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100266 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100267 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100268 if dtype in gtu.DTYPE_ATTRIBUTES:
269 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700270 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100271 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700272
Luke Hutton57287132023-02-06 14:54:18 +0000273 def constrictBatchSize(self, shape):
274 # Limit the batch size unless an explicit target shape set
275 if self.args.max_batch_size and not self.args.target_shapes:
276 shape[0] = min(shape[0], self.args.max_batch_size)
277 return shape
278
James Ward30124a82023-02-02 14:56:33 +0000279 def makeDimension(self):
280 return self.randInt(
281 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
282 )
283
Jeremy Johnson1271c442023-09-05 11:39:26 +0100284 def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
285 if errorName:
286 # No compliance for error tests
287 return None
288 # Create compliance meta data for expected output tensor
289 compliance_tens = {"mode": None}
290 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
291 mode = gtu.ComplianceMode.DOT_PRODUCT
292 compliance_tens["dot_product_info"] = {
293 "s": argsDict["s"],
294 "ks": argsDict["ks"],
295 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
296 }
297 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
298 mode = gtu.ComplianceMode.FP_SPECIAL
299 elif "compliance" in op and "ulp" in op["compliance"]:
300 mode = gtu.ComplianceMode.ULP
301 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
302 elif op["op"] == Op.REDUCE_PRODUCT:
303 mode = gtu.ComplianceMode.REDUCE_PRODUCT
304 else:
305 mode = gtu.ComplianceMode.EXACT
306 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
307
308 return compliance_tens
309
310 # Build Op functions
311 # Create the output tensor (calling OutputShaper as needed)
312 # Do final tweaks to attributes (if necessary for errorIf)
313 # Add Op into graph
314 # Return resulting tensor information or BuildInfo
315
316 class BuildInfo:
317 """Enhanced build information containing result tensor and associated compliance dict."""
318
319 def __init__(self, resultTensor, complianceDict):
320 self.resultTensor = resultTensor
321 self.complianceDict = complianceDict
Eric Kunzee5e26762020-10-13 16:11:07 -0700322
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100323 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
324 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
325
Matthew Haddon848efb42021-09-09 12:30:53 +0100326 # build_placeholder returns an int, ABS/other ops does not
327 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000328 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100329 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000330 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000331 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100332 return result_tens
333
334 # Ensure new output type has correct qinfo
335 if error_name == ErrorIf.WrongOutputType:
336 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000337 qinfo = [
338 TosaQuantGen.getZeroPoint(self, a.dtype),
339 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
340 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100341
342 # Invalidate Input/Output list for error if checks.
343 input_list = [a.name]
344 output_list = [result_tens.name]
345 pCount, cCount = op["operands"]
346 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000347 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
348 self, error_name, input_list, output_list
349 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100350
Les Bell729b0352021-11-24 10:28:21 +0000351 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100352 self.ser,
353 validator_fcns,
354 error_name,
355 op=op,
356 input_dtype=a.dtype,
357 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000358 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000359 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100360 input_list=input_list,
361 output_list=output_list,
362 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000363 ):
364 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100365
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000366 attr = None
367 if op["op"] == Op.NEGATE:
368 attr = ts.TosaSerializerAttribute()
369 attr.NegateAttribute(qinfo[0], qinfo[1])
370
371 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700372 return result_tens
373
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100374 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000375 result_tens = OutputShaper.binaryBroadcastOp(
376 self.ser, self.rng, a, b, error_name
377 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100378
379 # Invalidate Input/Output list for error if checks.
380 input_list = [a.name, b.name]
381 output_list = [result_tens.name]
382 pCount, cCount = op["operands"]
383 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
385 self, error_name, input_list, output_list
386 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100387
Les Bell729b0352021-11-24 10:28:21 +0000388 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100389 self.ser,
390 validator_fcns,
391 error_name,
392 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 input1=a,
394 input2=b,
395 input_dtype=a.dtype,
396 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000397 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100398 input_list=input_list,
399 output_list=output_list,
400 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000401 ):
402 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100403
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000404 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700405 return result_tens
406
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100407 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700408 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000409 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700410 return result_tens
411
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000412 def build_arithmetic_right_shift(
413 self, op, a, b, round, validator_fcns=None, error_name=None
414 ):
415 result_tens = OutputShaper.binaryBroadcastOp(
416 self.ser, self.rng, a, b, error_name
417 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100418
419 # Invalidate Input/Output list for error if checks.
420 input_list = [a.name, b.name]
421 output_list = [result_tens.name]
422 pCount, cCount = op["operands"]
423 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000424 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
425 self, error_name, input_list, output_list
426 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100427
Les Bell729b0352021-11-24 10:28:21 +0000428 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100429 self.ser,
430 validator_fcns,
431 error_name,
432 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000433 input1=a,
434 input2=b,
435 input_dtype=a.dtype,
436 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000437 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100438 input_list=input_list,
439 output_list=output_list,
440 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000441 ):
442 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800443
444 attr = ts.TosaSerializerAttribute()
445 attr.ArithmeticRightShiftAttribute(round)
446
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000447 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800448 return result_tens
449
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100450 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000451 result_tens = OutputShaper.binaryBroadcastOp(
452 self.ser, self.rng, a, b, error_name
453 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700454
455 # Special for multiply:
456 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100457 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700458 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100459 if error_name == ErrorIf.WrongOutputType:
460 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
461 outputDType = self.rng.choice(all_dtypes)
462 result_tens.setDtype(outputDType)
463
464 # Invalidate Input/Output list for error if checks.
465 input_list = [a.name, b.name]
466 output_list = [result_tens.name]
467 pCount, cCount = op["operands"]
468 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000469 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
470 self, error_name, input_list, output_list
471 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100472
Les Bell729b0352021-11-24 10:28:21 +0000473 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100474 self.ser,
475 validator_fcns,
476 error_name,
477 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000478 input1=a,
479 input2=b,
480 input_dtype=a.dtype,
481 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000482 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100483 input_list=input_list,
484 output_list=output_list,
485 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000486 ):
487 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700488
Kevin Chengaee1fac2020-11-11 13:54:06 -0800489 attr = ts.TosaSerializerAttribute()
490 attr.MulAttribute(shift)
491
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000492 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700493 return result_tens
494
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100495 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
496 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700497
Kevin Chengfe392ce2021-10-18 21:51:55 +0000498 attr = ts.TosaSerializerAttribute()
499 attr.TableAttribute(table)
500
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100501 # Invalidate Input/Output list for error if checks.
502 input_list = [a.name]
503 output_list = [result_tens.name]
504 pCount, cCount = op["operands"]
505 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000506 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
507 self, error_name, input_list, output_list
508 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100509
Les Bell729b0352021-11-24 10:28:21 +0000510 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100511 self.ser,
512 validator_fcns,
513 error_name,
514 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000515 input_shape=a.shape,
516 input_dtype=a.dtype,
517 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000518 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100519 input_list=input_list,
520 output_list=output_list,
521 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000522 ):
523 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100524
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000525 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700526
527 return result_tens
528
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100529 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
530 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
531
532 # Invalidate Input/Output list for error if checks.
533 input_list = [cond.name, a.name, b.name]
534 output_list = [result_tens.name]
535 pCount, cCount = op["operands"]
536 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000537 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
538 self, error_name, input_list, output_list
539 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100540
Les Bell729b0352021-11-24 10:28:21 +0000541 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100542 self.ser,
543 validator_fcns,
544 error_name,
545 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000546 input1=cond,
547 input2=a,
548 input3=b,
549 input_shape=a.shape,
550 input_dtype=a.dtype,
551 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000552 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100553 input_list=input_list,
554 output_list=output_list,
555 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000556 ):
557 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100558
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000559 self.ser.addOperator(
560 op["op"],
561 input_list,
562 output_list,
563 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700564 return result_tens
565
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100566 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000567 result_tens = OutputShaper.binaryComparisonOp(
568 self.ser, self.rng, a, b, error_name
569 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100570
571 # Invalidate Input/Output list for error if checks.
572 input_list = [a.name, b.name]
573 output_list = [result_tens.name]
574 pCount, cCount = op["operands"]
575 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000576 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
577 self, error_name, input_list, output_list
578 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100579
Les Bell729b0352021-11-24 10:28:21 +0000580 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581 self.ser,
582 validator_fcns,
583 error_name,
584 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000585 input1=a,
586 input2=b,
587 input_shape=a.shape,
588 input_dtype=a.dtype,
589 output_shape=result_tens.shape,
590 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000591 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100592 input_list=input_list,
593 output_list=output_list,
594 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000595 ):
596 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100597
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000598 self.ser.addOperator(
599 op["op"],
600 input_list,
601 output_list,
602 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700603 return result_tens
604
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100605 def build_argmax(self, op, a, axis, validator_fcns, error_name):
606 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
607
608 # Invalidate Input/Output list for error if checks.
609 input_list = [a.name]
610 output_list = [result_tens.name]
611 pCount, cCount = op["operands"]
612 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000613 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
614 self, error_name, input_list, output_list
615 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100616
Les Bell729b0352021-11-24 10:28:21 +0000617 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100618 self.ser,
619 validator_fcns,
620 error_name,
621 op=op,
622 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000623 input_shape=a.shape,
624 input_dtype=a.dtype,
625 output_shape=result_tens.shape,
626 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000627 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100628 input_list=input_list,
629 output_list=output_list,
630 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000631 ):
632 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700633
634 attr = ts.TosaSerializerAttribute()
635 attr.AxisAttribute(axis)
636
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000637 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700638 return result_tens
639
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000640 def build_pool2d(
641 self,
642 op,
643 input,
James Ward8b390432022-08-12 20:48:56 +0100644 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000645 stride,
646 pad,
647 kernel,
648 validator_fcns=None,
649 error_name=None,
650 qinfo=None,
651 ):
652 result_tens = OutputShaper.pool2dOp(
653 self.ser, self.rng, input, kernel, stride, pad, error_name
654 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100655
656 # Ensure new output type has correct qinfo
657 if error_name == ErrorIf.WrongInputType:
658 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000659 qinfo = [
660 TosaQuantGen.getZeroPoint(self, input.dtype),
661 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
662 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100663
664 # Invalidate Input/Output list for error if checks.
665 input_list = [input.name]
666 output_list = [result_tens.name]
667 pCount, cCount = op["operands"]
668 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000669 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
670 self, error_name, input_list, output_list
671 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100672
Les Bell729b0352021-11-24 10:28:21 +0000673 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100674 self.ser,
675 validator_fcns,
676 error_name,
677 op=op,
678 input_shape=input.shape,
679 input_dtype=input.dtype,
680 output_shape=result_tens.shape,
681 output_dtype=result_tens.dtype,
682 kernel=kernel,
683 stride=stride,
684 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000685 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000686 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100687 input_list=input_list,
688 output_list=output_list,
689 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000690 ):
691 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700692
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000693 if qinfo is None:
694 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700695
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000696 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100697 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000698
699 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700700 return result_tens
701
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MAX_POOL2D operator by delegating to build_pool2d.

    Max pooling has no accumulator type, so DType.UNKNOWN is supplied as
    build_pool2d's accum_dtype argument. All other arguments are forwarded
    unchanged.

    Returns:
        Whatever build_pool2d returns: the result tensor, or None when the
        ERROR_IF validators reject the configuration.
    """
    return self.build_pool2d(
        op,
        input,
        accum_dtype=DType.UNKNOWN,
        stride=stride,
        pad=pad,
        kernel=kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
726
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator and add it to the serializer.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        ifm: input feature-map tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype forwarded to OutputShaper.conv2dOp.
        strides: stride attribute.
        padding: padding attribute; must have exactly 4 entries.
        dilations: dilation attribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed as the last two ConvAttribute
            arguments. Defaults to [0, 0] when not supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # When testing a wrong (non-8-bit-integer) input type, recompute the zero
    # points from the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
798
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator and add it to the serializer.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        ifm: input feature-map tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype forwarded to OutputShaper.conv3dOp.
        strides: stride attribute.
        padding: padding attribute; must have exactly 6 entries.
        dilations: dilation attribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed as the last two ConvAttribute
            arguments. Defaults to [0, 0] when not supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # When testing a wrong (non-8-bit-integer) input type, recompute the zero
    # points from the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
870
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator and add it to the serializer.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        ifm: input feature-map tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype forwarded to
            OutputShaper.transposeConv2DOp.
        stride: stride attribute.
        out_pad: output padding attribute; must have exactly 4 entries.
        output_shape: requested output shape, used both for shaping the
            result and in the TransposeConvAttribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed as the last two
            TransposeConvAttribute arguments. Defaults to [0, 0] when not
            supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # When testing a wrong (non-8-bit-integer) input type, recompute the zero
    # points from the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
933
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and add it to the serializer.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        ifm: input feature-map tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype forwarded to
            OutputShaper.depthwiseConv2dOp.
        strides: stride attribute.
        padding: padding attribute.
        dilations: dilation attribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed as the last two ConvAttribute
            arguments. Defaults to [0, 0] when not supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # When testing a wrong (non-8-bit-integer) input type, recompute the zero
    # points from the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1004
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and add it to the serializer.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        ifm: input tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype, forwarded to the output shaper and
            the ERROR_IF validators.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed to FullyConnectedAttribute.
            Defaults to [0, 0] when not supplied.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1053
def build_matmul(
    self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator and attach compliance metadata.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        a: first input tensor.
        b: second input tensor.
        args_dict: argument dict; its "acc_type" entry is the accumulator
            dtype. It is also passed on to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: two zero-point values passed to MatMulAttribute. Defaults to
            [0, 0] when not supplied.

    Returns:
        TosaTestGen.BuildInfo wrapping the result tensor and its compliance
        metadata, or None if the ERROR_IF validators reject the
        configuration.
    """
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    if qinfo is None:
        # Same fallback as build_pool2d. Without it, qinfo[0] below raises
        # TypeError when no quantization info was supplied.
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001101
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REDUCE_* operator along ``axis``.

    validator_fcns now defaults to None, like every other builder here.
    Positional callers are unaffected.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        a: input tensor.
        axis: axis to reduce, serialized as an AxisAttribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1136
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes min_val and the larger one max_val. When testing
    ErrorIf.MaxSmallerMin, the pair is redrawn until the values differ, and
    the assignment is reversed so that max_val < min_val.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    clamp_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values cannot trigger the error, so keep drawing until they differ.
        while bounds[0] == bounds[1]:
            bounds = [
                self.getRandNumberDType(a.dtype),
                self.getRandNumberDType(a.dtype),
            ]
        min_val, max_val = max(bounds), min(bounds)

    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [clamp_out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=clamp_out.shape,
        input_dtype=a.dtype,
        output_dtype=clamp_out.dtype,
        result_tensors=[clamp_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Integer bounds go in the first pair; float slots are zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 values are handed to reference_model as fp32.
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float bounds go in the second pair; integer slots are zeroed.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return clamp_out
1192
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator whose attribute is a random FP32 value.

    No ERROR_IF validation is performed here; validator_fcns is accepted for
    interface parity with the other builders but is not used.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out.name], leaky_attr)
    return out
1201
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator wired to the single input ``a``.

    NOTE(review): per the comment above, an additional type/input is still
    needed. No attribute is set and no ERROR_IF validation is run.

    Returns:
        The result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1208
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator (no attributes).

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration (the operator is then not added).
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1239
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator (no attributes).

    Returns:
        The result tensor when the ERROR_IF validators accept the
        configuration, otherwise None (and the operator is not added).
    """
    tanh_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tanh_out.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tanh_out.shape,
        input_dtype=a.dtype,
        output_dtype=tanh_out.dtype,
        result_tensors=[tanh_out],
        input_list=names_in,
        output_list=names_out,
        num_operands=n_placeholders + n_consts,
    )
    if checks_pass:
        self.ser.addOperator(op["op"], names_in, names_out)
        return tanh_out
    return None
1270
def build_erf(self, op, a, validator_fcns=None, error_name=None):
    """Build an ERF operator (no attributes).

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    erf_out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholder_count, const_count = op["operands"]
    total_operands = placeholder_count + const_count

    # ERROR_IF tests may drop or add entries in these lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [erf_out.name]
    )

    if TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=erf_out.shape,
        input_dtype=a.dtype,
        output_dtype=erf_out.dtype,
        result_tensors=[erf_out],
        input_list=ins,
        output_list=outs,
        num_operands=total_operands,
    ):
        self.ser.addOperator(op["op"], ins, outs)
        return erf_out
    return None
1301
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator over a variable number of input tensors.

    The positional arguments are the input tensors followed by the
    concatenation axis as the final element. The axis is asserted to be an
    int unless ErrorIf.WrongInputType is being tested.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # The axis travels at the tail of the varargs so the tensor count can vary.
    axis = a[-1]
    tensors = a[:-1]

    concat_out = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor.name for tensor in tensors],
        [concat_out.name],
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=concat_out.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=concat_out.dtype,
        inputs=tensors,
        result_tensors=[concat_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return concat_out
Eric Kunzee5e26762020-10-13 16:11:07 -07001350
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator.

    Args:
        op: operator description dict; the "op" and "operands" keys are read.
        a: input tensor.
        padding: padding array; it is flattened for the PadAttribute.
        pad_const_int: integer pad constant for the PadAttribute.
        pad_const_float: float pad constant for the PadAttribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: the ErrorIf being tested, or None for a valid test.
        qinfo: forwarded to the ERROR_IF validators only.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    pad_out = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute is built before validation, as in the original, because
    # PadAttribute is handed the serializer's builder.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
    )

    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [pad_out.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=pad_out.shape,
        input_dtype=a.dtype,
        output_dtype=pad_out.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[pad_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)
    return pad_out
Eric Kunzee5e26762020-10-13 16:11:07 -07001399
def build_dim(
    self,
    op,
    a,
    axis,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator whose AxisAttribute is ``axis``.

    qinfo is accepted for interface parity with the other builders but is
    not used.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    dim_out = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)

    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [dim_out.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=dim_out.shape,
        output_dtype=dim_out.dtype,
        result_tensors=[dim_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return dim_out
1442
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE operator targeting ``newShape``.

    Returns:
        The result tensor, or None if the ERROR_IF validators reject the
        configuration.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=reshaped.shape,
        input_dtype=a.dtype,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    shape_attr = ts.TosaSerializerAttribute()
    shape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, shape_attr)
    return reshaped
1478
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REVERSE operator on tensor ``a`` along ``axis``.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        a: Input tensor.
        axis: Axis stored in the AxisAttribute and checked by validators.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], ins, outs, axis_attr)
    return out_tens

def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator on tensor ``a`` using permutation ``perms``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        a: Input tensor.
        perms: Permutation given to the shaper, the TransposeAttribute and
            the ERROR_IF validators.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation, as it always has been here.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], ins, outs, perm_attr)
    return out_tens

def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator taking ``size`` elements of ``a`` from ``start``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        a: Input tensor.
        start: Start coordinates, passed to the shaper, validators and the
            SliceAttribute.
        size: Slice extents, passed to the same places as ``start``.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        start=start,
        size=size,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not accepted:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], ins, outs, slice_attr)
    return out_tens

def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator repeating tensor ``a`` by ``multiples``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        a: Input tensor.
        multiples: Per-dimension repeat counts for the shaper and the
            TileAttribute.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tens.shape,
        input_dtype=a.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not accepted:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], ins, outs, tile_attr)
    return out_tens

def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator with operands ``values`` and a generated index tensor.

    A new constant INT32 index tensor of shape ``[values.shape[0], W]`` is
    created. ``W`` is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``. Every index lies in
    ``[0, values.shape[1])``, so it never exceeds the dimensions of ``values``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        values: Tensor that is indexed.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    k_extent = values.shape[1]
    # RNG draw order (width first, then index values) is kept stable so a
    # given seed produces the same test.
    width = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_values = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], width])
    )
    indices = self.ser.addConst(index_values.shape, DType.INT32, index_values)

    out_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tens.shape,
        input_dtype=values.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return out_tens

def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator with operands ``values_in``, indices and ``input``.

    A new constant INT32 index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is created. Every index lies in
    ``[0, values_in.shape[1])``, so it never exceeds the dimensions of
    ``values_in``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        values_in: First operand tensor.
        input: Third operand tensor; its ``shape[1]`` sets the index width.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    k_extent = values_in.shape[1]
    width = input.shape[1]
    index_values = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], width])
    )
    indices = self.ser.addConst(index_values.shape, DType.INT32, index_values)

    out_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tens.name],
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return out_tens

def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on tensor ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` go to the output shaper,
    the ERROR_IF validators and the serialized ResizeAttribute.
    ``input_dtype`` and ``output_dtype`` go to the shaper and the validators.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    shaper_args = (
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )
    out_tens = OutputShaper.resizeOp(self.ser, self.rng, *shaper_args)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tens.shape,
        offset=offset,
        border=border,
        input_list=ins,
        output_list=outs,
        result_tensors=[out_tens],
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], ins, outs, resize_attr)
    return out_tens

def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN-style operator with two inputs and two outputs.

    Each output tensor is produced from its input by ``OutputShaper.unaryOp``.
    Only the first result tensor is returned. ``validator_fcns`` is not used:
    this builder performs no ERROR_IF validation.

    Args:
        op: Operator definition dict. As in every other builder, the operator
            code serialized is ``op["op"]``, not the dict itself.
        val: First input tensor.
        val2: Second input tensor.
        validator_fcns: Unused.
        error_name: Forwarded to the output shaper.

    Returns:
        The result tensor derived from ``val``.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Pass the Op code, not the whole definition dict, to the serializer.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens

def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for signature
    compatibility with the other builders but are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val

# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting tensor ``val`` to ``out_dtype``.

    The output tensor comes from ``OutputShaper.typeConversionOp``.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        val: Input tensor.
        out_dtype: Requested output data type.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the input/output name lists.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tens.shape,
        input_dtype=val.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=ins,
        output_list=outs,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], ins, outs)
    return out_tens

def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting tensor ``val`` to ``out_dtype``.

    Random input/output zero points are chosen from the tensor dtypes. For
    the zero-point ERROR_IF cases, non-zero or otherwise invalid zero points
    are chosen instead. A random scale is drawn per channel (``per_channel``
    uses ``val.shape[-1]`` channels, otherwise 1). Each scale is converted to
    a multiplier/shift pair with ``TosaQuantGen.computeMultiplierAndShift``.
    For valid (``error_name is None``) ``scale32`` tests, the input
    placeholder data stored on disk is clamped in place to the range that
    each shift allows.

    Note: the zero points and scales come from a fixed sequence of RNG draws.
    Reordering these statements changes the generated values for a given
    RNG state.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized and the two
            entries of ``op["operands"]`` are summed into ``num_operands``.
        val: Input tensor. In the scale32 valid path it must have a
            ``placeholderFilename``.
        out_dtype: Requested output data type.
        scale32: Selects the 32-bit (True) or 16-bit (False) scale range.
        double_round: Stored in the RescaleAttribute.
        per_channel: Use one scale per channel of the last dimension.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of scale/multiplier/shift entries to generate.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Each branch that picks a zero point also widens the type width by one
    # bit. Presumably this leaves headroom for the zero-point offset; confirm.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        # Force a non-zero value so the error case is actually exercised.
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Same selection scheme as above, for the output side.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        # Force a non-zero value so the error case is actually exercised.
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    # where a is drawn uniformly from [0, 1) by rng.random.

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # NOTE(review): this `pass` is a no-op; the clip below still runs.
        pass
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    # Per-channel inclusive bounds [-(1<<(shift-1)), (1<<(shift-1)) - 1] on
    # the zero-point-adjusted input, used to clamp the placeholder data below.
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        # Clamp in int64 to avoid overflow, then cast back to the data's dtype.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens

def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    In the normal case this is a BOOL constant with an empty shape (size 1,
    rank 0) holding ``cond``. For the ``CondIfCondNotMatchingBool`` ERROR_IF
    case, a non-BOOL type is obtained from ``gtu.get_wrong_output_type``. For
    ``CondIfCondShapeNotSizeOne``, the shape is randomly ``[2]`` or ``[1, 2]``.

    Returns:
        The tensor returned by ``self.ser.addConst``.
    """
    # The wrong type is drawn before the shape to keep the RNG order stable.
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])

def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored and its
    name is rebound inside the body. For the
    ``CondIfOutputList{Then,Else}GraphMismatch`` ERROR_IF cases, the relevant
    block outputs a constant with a deliberately different shape.

    Note: the constant data comes from a fixed sequence of RNG draws.
    Reordering these statements changes the generated values for a given
    RNG state.

    Args:
        op: Operator definition dict; ``op["op"]`` is serialized.
        then_tens: Tensor whose shape sets the output shape.
        else_tens: Unused.
        cond: Value stored in the condition tensor.
        validator_fcns: ERROR_IF validator callables.
        error_name: ERROR_IF case being generated, or None for a valid test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    # For cond_if with constants, we're supplied with then/else tensors that we ignore
    # (except for the generated shape) and the condition. Build Then/Else blocks
    # and fill them with const nodes for the body.

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Make then/else tensors
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        # Dims > 3 may shrink or grow; small dims only grow, so every
        # dimension stays positive.
        # NOTE(review): for a rank-0 then_tens this loop changes nothing and
        # incorrect_shape equals out_shape; confirm such shapes never occur.
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    self.ser.addBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_tens)

    # Validation runs after both blocks exist so validators can inspect them.
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens

def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to ``a`` and ``b``.

    The THEN block computes ``then_op(a, b)`` and the ELSE block
    ``else_op(a, b)``. The op pair is picked from ``a.dtype``:
    ADD/SUB for FP32/BF16/FP16/INT32 and
    LOGICAL_RIGHT_SHIFT/LOGICAL_LEFT_SHIFT for INT8/INT16.

    For the CondIf*GraphMismatch ERROR_IF tests, the selected block gets an
    input (or output) tensor whose shape deliberately differs from ``a``.

    Args:
        op: operator entry from TOSA_OP_LIST (a dict, ``op["op"]`` is used).
        a: first input tensor; its shape/dtype is used for the result.
        b: second input tensor (presumably same shape/dtype as ``a`` - TODO
            confirm against the tensor generator).
        cond: condition value handed to ``_get_condition_tensor``.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The COND_IF result tensor, or None if ERROR_IF validation rejects
        the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Shift every dimension so the block tensor can never match ``a``
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # NOTE: the loop variable must not be named ``op``. That would shadow the
    # operator dict parameter, and the ERROR_IF validation below would receive
    # the ELSE block's Op enum instead of the operator entry.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2147
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` while an iteration counter is > 0.

    The COND block computes ``iter > 0``. The BODY block computes
    ``acc = a + acc`` and ``iter = iter - 1`` and passes ``a`` through.
    For the loop ERROR_IF tests, tensors of one graph get deliberately
    mismatched shapes. For the CondGraphOutput* errors, the COND output gets
    a non-BOOL dtype or a non-size-one shape.

    Args:
        op: operator entry from TOSA_OP_LIST (``op["op"]`` is used).
        a: tensor added into the accumulator on each iteration.
        iter_val: initial value of the scalar INT32 iteration counter.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The accumulator output tensor, or None if ERROR_IF validation rejects
        the generated test.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    def perturb_shape(tensor):
        # Shift each dimension of ``tensor`` in place by a small non-zero offset
        for idx, dim in enumerate(tensor.shape):
            tensor.shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])

    # Accumulator tensor, zero initialised
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_template = acc
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_template = deepcopy(acc)
        perturb_shape(acc_out_template)
    acc_out = self.ser.addIntermediate(acc_out_template.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        incorrect_iter = deepcopy(iter_tens)
        perturb_shape(incorrect_iter)
        if not incorrect_iter.shape:
            # The counter is rank 0, so give the copy a dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        perturb_shape(incorrect_acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tensor in cond_inputs:
        self.ser.addInputTensor(tensor)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Local intermediate tensors have to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tensor in body_inputs:
        self.ser.addInputTensor(tensor)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = incorrect_iter, incorrect_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tensor in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tensor)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2268
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Build an FFT2D operator from the two input tensors ``val1`` and ``val2``.

    Result tensors come from ``OutputShaper.fft2dOp``. For ERROR_IF tests the
    input/output name lists may be invalidated by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` before validation.

    Args:
        op: operator entry from TOSA_OP_LIST (``op["op"]`` and
            ``op["operands"]`` are used).
        val1, val2: the two input tensors.
        inverse: forwarded to the validators and to ``FFTAttribute``.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The result tensors from ``OutputShaper.fft2dOp``, or None when
        ERROR_IF validation rejects the test (no operator is added then).
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2310
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Build an RFFT2D operator from the single input tensor ``val``.

    Result tensors come from ``OutputShaper.rfft2dOp``. For ERROR_IF tests the
    input/output name lists may be invalidated by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` before validation.

    Args:
        op: operator entry from TOSA_OP_LIST (``op["op"]`` and
            ``op["operands"]`` are used).
        val: the input tensor.
        validator_fcns: ERROR_IF validator functions to run, or None.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The result tensors from ``OutputShaper.rfft2dOp``, or None when
        ERROR_IF validation rejects the test (no operator is added then).
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2344
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters used to generate tests for ``op``.

    Requested filters are intersected with what the operator allows
    (``op["rank"]`` and ``op["types"]``). For "negative" tests the ERROR_IF
    ``validator``'s ``param_reqs`` override each filter when not None.

    Args:
        op: operator entry from TOSA_OP_LIST.
        shapeFilter: requested shapes; empty/None means "no shape filter".
        rankFilter: requested ranks, or None for the default range.
        dtypeFilter: requested dtypes, or None for every operator dtype.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; needed for "negative" tests.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a "negative" test without a validator or any other ``testType``.
    """
    # Default testing rank range: 1-4 inclusive, to keep test sizes small
    default_test_rank_range = range(1, 5)
    shapeFilter = shapeFilter or [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    else:
        op_rank_range = range(rmin, rmax + 1)
        # Without any filter, use whichever of the default range and the
        # operator's range is smaller
        if shapeFilter[0] is None and len(default_test_rank_range) < len(
            op_rank_range
        ):
            cleanRankFilter = default_test_rank_range
        else:
            cleanRankFilter = op_rank_range

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        # Operator dtypes that were requested; a dtype list matches on its
        # first entry
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    if testType != "negative" or validator is None:
        return None

    # Parameters required by the ERROR_IF test take precedence
    param_reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = param_reqs["rank"]
    dtype_req = param_reqs["dtype"]
    shape_req = param_reqs["shape"]

    return {
        # Reduce number of shapes to keep test numbers small
        "shapeFilter": shapeFilter[:2] if shape_req is None else shape_req,
        "rankFilter": cleanRankFilter if rank_req is None else rank_req,
        "dtypeFilter": cleanDtypeFilter if dtype_req is None else dtype_req,
    }
2425
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for operator ``opName``.

    For "negative" tests one set of tests is generated per ERROR_IF
    validator of the op; for "positive" tests a single set is generated, and
    tests flagged by the op's ``invalid_test_validators`` are removed.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: requested shapes; None (default) means no shape filter
            and is treated as ``[None]``.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".

    Returns:
        list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples (empty if no filters could be created).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Avoid a shared mutable default argument; [None] means "no filter"
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2532
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test for ``opName`` and serialize it with its desc.json meta data.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name, passed to ``createSerializer``.
        dtype_or_dtypeList: one dtype used for every operand, or an explicit
            per-operand list of dtypes.
        error_name: ERROR_IF being tested, or None for a positive test.
        shapeList: list of operand shapes.
        testArgs: build function arguments; a single-element list holding a
            dict selects the argsDict interface.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
        TypeError: re-raised from the build function after printing its inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT has a variable number of inputs: one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Build the random tensor operands and the test
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new testArgs interface with an argsDict dictionary
    if len(testArgs) == 1 and isinstance(testArgs[0], dict):
        argsDict = testArgs[0]
        assert "dg_type" in argsDict
        tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
        if tvgInfo.dataGenDict:
            tensMeta["data_gen"] = tvgInfo.dataGenDict
        tens = tvgInfo.tensorList
    else:
        tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        build_args = [self, op, *tens, *testArgs]
        build_kwargs = {}
        if error_if_validators is None:
            # Without validators, qinfo is passed positionally after testArgs
            if qinfo is not None:
                build_args.append(qinfo)
        else:
            build_kwargs["validator_fcns"] = error_if_validators
            build_kwargs["error_name"] = error_name
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
        result = build_fcn(*build_args, **build_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if result:
        # The test is valid, serialize it
        if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
            # Add the compliance meta data
            # NOTE: This currently expects only one result output
            tensMeta["compliance"] = {
                "version": "0.1",
                "tensors": {result.resultTensor.name: result.complianceDict},
            }
        self.serialize("test", tensMeta)
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002646
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST into per-kernel ops.

    Each template is copied (shallow ``dict.copy``) once per kernel size, the
    copy gets ``"filter"`` set to the kernel and ``"template"`` set to False,
    and afterwards every entry whose ``"template"`` is truthy is deleted.
    Does nothing if "conv2d_TEMPLATE" is already gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to generate convolution ops for
    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_kernel_variant(prefix, kernel):
        # Register e.g. "conv2d_3x1" as a copy of "conv2d_TEMPLATE"
        variant = op_list["{}_TEMPLATE".format(prefix)].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(prefix, "x".join("{}".format(d) for d in kernel))] = variant

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(prefix, kernel)

    for kernel in kernels_3d:
        add_kernel_variant("conv3d", kernel)

    # Delete the templates now the dynamic ops exist. The names are collected
    # first because deleting from a dict while iterating over it is an error.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2702
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must provide an "operands" 2-tuple, a
    "build_fcn" 4-tuple, a "types" entry and an "op" entry. A missing
    "rank" is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: for the first op with a missing or malformed required
            field. For a bad tuple, the original lookup/unpack error is kept
            as ``__cause__``.
    """
    for op, entry in self.TOSA_OP_LIST.items():
        # Required fields - unpacking also checks the tuple lengths
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            ) from err

        try:
            _fcn, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op
                )
            ) from err

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
            )

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002744
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the TOSA Op enum value (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build_operator function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional fields used by the generator include 'qgen', 'error_if_validators',
# 'invalid_test_validators' and, for template entries, 'template'.
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002796
2797 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002798 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002799 "argmax": {
2800 "op": Op.ARGMAX,
2801 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002802 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002803 "build_fcn": (
2804 build_argmax,
2805 TosaTensorGen.tgBasic,
2806 TosaTensorValuesGen.tvgDefault,
2807 TosaArgGen.agAxis,
2808 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002809 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002810 "error_if_validators": (
2811 TosaErrorValidator.evAxisSmallerZero,
2812 TosaErrorValidator.evAxisLargerRank,
2813 TosaErrorValidator.evArgmaxOutputRankMismatch,
2814 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2815 TosaErrorValidator.evWrongRank,
2816 TosaErrorValidator.evWrongInputType,
2817 TosaErrorValidator.evWrongOutputType,
2818 TosaErrorValidator.evWrongInputList,
2819 TosaErrorValidator.evWrongOutputList,
2820 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002821 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002822 "avg_pool2d": {
2823 "op": Op.AVG_POOL2D,
2824 "operands": (1, 0),
2825 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002826 "build_fcn": (
2827 build_pool2d,
2828 TosaTensorGen.tgNHWC,
2829 TosaTensorValuesGen.tvgDefault,
2830 TosaArgGen.agPooling,
2831 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002832 "qgen": TosaQuantGen.qgUnary,
2833 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002834 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002835 "error_if_validators": (
2836 TosaErrorValidator.evKernelSmallerOne,
2837 TosaErrorValidator.evStrideSmallerOne,
2838 TosaErrorValidator.evPadSmallerZero,
2839 TosaErrorValidator.evWrongRank,
2840 TosaErrorValidator.evWrongInputType,
2841 TosaErrorValidator.evWrongOutputType,
2842 TosaErrorValidator.evWrongInputList,
2843 TosaErrorValidator.evWrongOutputList,
2844 TosaErrorValidator.evInputZeroPointNotZero,
2845 TosaErrorValidator.evOutputZeroPointNotZero,
2846 TosaErrorValidator.evPadLargerEqualKernel,
2847 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002848 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002850 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002851 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002852 "conv2d_TEMPLATE": {
2853 "op": Op.CONV2D,
2854 "operands": (1, 2),
2855 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002856 "build_fcn": (
2857 build_conv2d,
2858 TosaTensorGen.tgConv2D,
2859 TosaTensorValuesGen.tvgDefault,
2860 TosaArgGen.agConv,
2861 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002862 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002863 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002864 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2865 "error_if_validators": (
2866 TosaErrorValidator.evWrongInputType,
2867 TosaErrorValidator.evWrongOutputType,
2868 TosaErrorValidator.evWrongInputList,
2869 TosaErrorValidator.evWrongOutputList,
2870 TosaErrorValidator.evInputZeroPointNotZero,
2871 TosaErrorValidator.evWeightZeroPointNotZero,
2872 TosaErrorValidator.evPadSmallerZero,
2873 TosaErrorValidator.evStrideSmallerOne,
2874 TosaErrorValidator.evDilationSmallerOne,
2875 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002876 TosaErrorValidator.evConvOutputShapeMismatch,
2877 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002878 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002879 "template": True,
2880 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002881 # Templated operator. Filled in by createDynamicOpLists
2882 "conv3d_TEMPLATE": {
2883 "op": Op.CONV3D,
2884 "operands": (1, 2),
2885 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002886 "build_fcn": (
2887 build_conv3d,
2888 TosaTensorGen.tgConv3D,
2889 TosaTensorValuesGen.tvgDefault,
2890 TosaArgGen.agConv,
2891 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002892 "qgen": TosaQuantGen.qgConv,
2893 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002894 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2895 "error_if_validators": (
2896 TosaErrorValidator.evWrongInputType,
2897 TosaErrorValidator.evWrongOutputType,
2898 TosaErrorValidator.evWrongInputList,
2899 TosaErrorValidator.evWrongOutputList,
2900 TosaErrorValidator.evInputZeroPointNotZero,
2901 TosaErrorValidator.evWeightZeroPointNotZero,
2902 TosaErrorValidator.evPadSmallerZero,
2903 TosaErrorValidator.evStrideSmallerOne,
2904 TosaErrorValidator.evDilationSmallerOne,
2905 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002906 TosaErrorValidator.evConvOutputShapeMismatch,
2907 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002908 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002909 "template": True,
2910 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002911 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002912 "depthwise_conv2d_TEMPLATE": {
2913 "op": Op.DEPTHWISE_CONV2D,
2914 "operands": (1, 2),
2915 "filter": [1, 1],
2916 "rank": (4, 4),
2917 "build_fcn": (
2918 build_depthwise_conv2d,
2919 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002920 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002921 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002922 ),
2923 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002924 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002925 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2926 "error_if_validators": (
2927 TosaErrorValidator.evWrongInputType,
2928 TosaErrorValidator.evWrongOutputType,
2929 TosaErrorValidator.evWrongInputList,
2930 TosaErrorValidator.evWrongOutputList,
2931 TosaErrorValidator.evInputZeroPointNotZero,
2932 TosaErrorValidator.evWeightZeroPointNotZero,
2933 TosaErrorValidator.evPadSmallerZero,
2934 TosaErrorValidator.evStrideSmallerOne,
2935 TosaErrorValidator.evDilationSmallerOne,
2936 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002937 TosaErrorValidator.evConvOutputShapeMismatch,
2938 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002939 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002940 "template": True,
2941 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002942 "fully_connected": {
2943 "op": Op.FULLY_CONNECTED,
2944 "operands": (1, 2),
2945 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002946 "build_fcn": (
2947 build_fully_connected,
2948 TosaTensorGen.tgFullyConnected,
2949 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002950 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002951 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002952 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002953 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002954 "error_if_validators": (
2955 TosaErrorValidator.evInputZeroPointNotZero,
2956 TosaErrorValidator.evWeightZeroPointNotZero,
2957 TosaErrorValidator.evWrongRank,
2958 TosaErrorValidator.evWrongInputType,
2959 TosaErrorValidator.evWrongOutputType,
2960 TosaErrorValidator.evWrongInputList,
2961 TosaErrorValidator.evWrongOutputList,
2962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002964 "matmul": {
2965 "op": Op.MATMUL,
2966 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002967 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002968 "build_fcn": (
2969 build_matmul,
2970 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002971 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01002972 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 "qgen": TosaQuantGen.qgMatmul,
2975 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002976 "error_if_validators": (
2977 TosaErrorValidator.evInputZeroPointNotZero,
2978 TosaErrorValidator.evWrongRank,
2979 TosaErrorValidator.evWrongInputType,
2980 TosaErrorValidator.evWrongOutputType,
2981 TosaErrorValidator.evWrongInputList,
2982 TosaErrorValidator.evWrongOutputList,
2983 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01002984 "data_gen": {
2985 "fp": (gtu.DataGenType.DOT_PRODUCT,),
2986 "int": (gtu.DataGenType.PSEUDO_RANDOM,),
2987 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002988 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002989 "max_pool2d": {
2990 "op": Op.MAX_POOL2D,
2991 "operands": (1, 0),
2992 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002993 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002994 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002995 TosaTensorGen.tgNHWC,
2996 TosaTensorValuesGen.tvgDefault,
2997 TosaArgGen.agPooling,
2998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002999 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00003000 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003001 "error_if_validators": (
3002 TosaErrorValidator.evKernelSmallerOne,
3003 TosaErrorValidator.evStrideSmallerOne,
3004 TosaErrorValidator.evPadSmallerZero,
3005 TosaErrorValidator.evWrongRank,
3006 TosaErrorValidator.evWrongInputType,
3007 TosaErrorValidator.evWrongOutputType,
3008 TosaErrorValidator.evWrongInputList,
3009 TosaErrorValidator.evWrongOutputList,
3010 TosaErrorValidator.evPadLargerEqualKernel,
3011 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003012 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003013 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003014 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003015 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003016 "transpose_conv2d_TEMPLATE": {
3017 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003018 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003019 "rank": (4, 4),
3020 "build_fcn": (
3021 build_transpose_conv2d,
3022 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003023 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003024 TosaArgGen.agTransposeConv2D,
3025 ),
3026 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003027 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003028 "invalid_test_validators": (
3029 TosaInvalidValidator.ivHeightWidthInvalid,
3030 TosaInvalidValidator.ivNonPositiveOutputShape,
3031 ),
3032 "error_if_validators": (
3033 TosaErrorValidator.evWrongInputType,
3034 TosaErrorValidator.evWrongOutputType,
3035 TosaErrorValidator.evWrongInputList,
3036 TosaErrorValidator.evWrongOutputList,
3037 TosaErrorValidator.evInputZeroPointNotZero,
3038 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003039 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003040 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003041 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003042 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003043 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003044 "template": True,
3045 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003046 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003047 "clamp": {
3048 "op": Op.CLAMP,
3049 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003050 "build_fcn": (
3051 build_clamp,
3052 TosaTensorGen.tgBasic,
3053 TosaTensorValuesGen.tvgDefault,
3054 None,
3055 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003056 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003057 "error_if_validators": (
3058 TosaErrorValidator.evMaxSmallerMin,
3059 TosaErrorValidator.evWrongInputType,
3060 TosaErrorValidator.evWrongOutputType,
3061 TosaErrorValidator.evWrongInputList,
3062 TosaErrorValidator.evWrongOutputList,
3063 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003064 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003065 "sigmoid": {
3066 "op": Op.SIGMOID,
3067 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003068 "build_fcn": (
3069 build_sigmoid,
3070 TosaTensorGen.tgBasic,
3071 TosaTensorValuesGen.tvgDefault,
3072 None,
3073 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003074 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003075 "error_if_validators": (
3076 TosaErrorValidator.evWrongInputType,
3077 TosaErrorValidator.evWrongOutputType,
3078 TosaErrorValidator.evWrongInputList,
3079 TosaErrorValidator.evWrongOutputList,
3080 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003081 },
3082 "tanh": {
3083 "op": Op.TANH,
3084 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003085 "build_fcn": (
3086 build_tanh,
3087 TosaTensorGen.tgBasic,
3088 TosaTensorValuesGen.tvgDefault,
3089 None,
3090 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003091 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003092 "error_if_validators": (
3093 TosaErrorValidator.evWrongInputType,
3094 TosaErrorValidator.evWrongOutputType,
3095 TosaErrorValidator.evWrongInputList,
3096 TosaErrorValidator.evWrongOutputList,
3097 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003098 },
Won Jeon78155c62023-06-10 00:20:04 +00003099 "erf": {
3100 "op": Op.ERF,
3101 "operands": (1, 0),
3102 "build_fcn": (
3103 build_erf,
3104 TosaTensorGen.tgBasic,
3105 TosaTensorValuesGen.tvgDefault,
3106 None,
3107 ),
3108 "types": TYPE_FP,
3109 "error_if_validators": (
3110 TosaErrorValidator.evWrongInputType,
3111 TosaErrorValidator.evWrongOutputType,
3112 TosaErrorValidator.evWrongInputList,
3113 TosaErrorValidator.evWrongOutputList,
3114 ),
3115 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 # Elementwise Binary Operators
3117 "add": {
3118 "op": Op.ADD,
3119 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 "build_fcn": (
3121 build_binary_broadcast,
3122 TosaTensorGen.tgBroadcastFuzz,
3123 TosaTensorValuesGen.tvgAddSub,
3124 None,
3125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003127 "error_if_validators": (
3128 TosaErrorValidator.evRankMismatch,
3129 TosaErrorValidator.evWrongInputType,
3130 TosaErrorValidator.evWrongOutputType,
3131 TosaErrorValidator.evWrongInputList,
3132 TosaErrorValidator.evWrongOutputList,
3133 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003134 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003135 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003136 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003137 "arithmetic_right_shift": {
3138 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3139 "operands": (2, 0),
3140 "build_fcn": (
3141 build_arithmetic_right_shift,
3142 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003143 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003144 TosaArgGen.agArithmeticRightShift,
3145 ),
3146 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003147 "error_if_validators": (
3148 TosaErrorValidator.evRankMismatch,
3149 TosaErrorValidator.evWrongInputType,
3150 TosaErrorValidator.evWrongOutputType,
3151 TosaErrorValidator.evWrongInputList,
3152 TosaErrorValidator.evWrongOutputList,
3153 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003154 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003155 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003156 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 "bitwise_and": {
3158 "op": Op.BITWISE_AND,
3159 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003160 "build_fcn": (
3161 build_binary_broadcast,
3162 TosaTensorGen.tgBroadcastFuzz,
3163 TosaTensorValuesGen.tvgDefault,
3164 None,
3165 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003167 "error_if_validators": (
3168 TosaErrorValidator.evRankMismatch,
3169 TosaErrorValidator.evWrongInputType,
3170 TosaErrorValidator.evWrongOutputType,
3171 TosaErrorValidator.evWrongInputList,
3172 TosaErrorValidator.evWrongOutputList,
3173 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003174 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003175 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "bitwise_or": {
3178 "op": Op.BITWISE_OR,
3179 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003180 "build_fcn": (
3181 build_binary_broadcast,
3182 TosaTensorGen.tgBroadcastFuzz,
3183 TosaTensorValuesGen.tvgDefault,
3184 None,
3185 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003187 "error_if_validators": (
3188 TosaErrorValidator.evRankMismatch,
3189 TosaErrorValidator.evWrongInputType,
3190 TosaErrorValidator.evWrongOutputType,
3191 TosaErrorValidator.evWrongInputList,
3192 TosaErrorValidator.evWrongOutputList,
3193 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003194 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003196 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003197 "bitwise_xor": {
3198 "op": Op.BITWISE_XOR,
3199 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003200 "build_fcn": (
3201 build_binary_broadcast,
3202 TosaTensorGen.tgBroadcastFuzz,
3203 TosaTensorValuesGen.tvgDefault,
3204 None,
3205 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003206 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003207 "error_if_validators": (
3208 TosaErrorValidator.evRankMismatch,
3209 TosaErrorValidator.evWrongInputType,
3210 TosaErrorValidator.evWrongOutputType,
3211 TosaErrorValidator.evWrongInputList,
3212 TosaErrorValidator.evWrongOutputList,
3213 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003214 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003216 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003217 "intdiv": {
3218 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003219 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003220 "build_fcn": (
3221 build_binary_broadcast,
3222 TosaTensorGen.tgBroadcastFuzz,
3223 TosaTensorValuesGen.tvgIntDiv,
3224 None,
3225 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003226 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003227 "error_if_validators": (
3228 TosaErrorValidator.evRankMismatch,
3229 TosaErrorValidator.evWrongInputType,
3230 TosaErrorValidator.evWrongOutputType,
3231 TosaErrorValidator.evWrongInputList,
3232 TosaErrorValidator.evWrongOutputList,
3233 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003234 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003236 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003237 "logical_and": {
3238 "op": Op.LOGICAL_AND,
3239 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003240 "build_fcn": (
3241 build_binary_broadcast,
3242 TosaTensorGen.tgBroadcastFuzz,
3243 TosaTensorValuesGen.tvgDefault,
3244 None,
3245 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003247 "error_if_validators": (
3248 TosaErrorValidator.evRankMismatch,
3249 TosaErrorValidator.evWrongInputType,
3250 TosaErrorValidator.evWrongOutputType,
3251 TosaErrorValidator.evWrongInputList,
3252 TosaErrorValidator.evWrongOutputList,
3253 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003254 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003255 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 "logical_left_shift": {
3258 "op": Op.LOGICAL_LEFT_SHIFT,
3259 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003260 "build_fcn": (
3261 build_binary_broadcast,
3262 TosaTensorGen.tgBroadcastFuzz,
3263 TosaTensorValuesGen.tvgLogicalShift,
3264 None,
3265 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003266 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003267 "error_if_validators": (
3268 TosaErrorValidator.evRankMismatch,
3269 TosaErrorValidator.evWrongInputType,
3270 TosaErrorValidator.evWrongOutputType,
3271 TosaErrorValidator.evWrongInputList,
3272 TosaErrorValidator.evWrongOutputList,
3273 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003274 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003275 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "logical_right_shift": {
3278 "op": Op.LOGICAL_RIGHT_SHIFT,
3279 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003280 "build_fcn": (
3281 build_binary_broadcast,
3282 TosaTensorGen.tgBroadcastFuzz,
3283 TosaTensorValuesGen.tvgLogicalShift,
3284 None,
3285 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003287 "error_if_validators": (
3288 TosaErrorValidator.evRankMismatch,
3289 TosaErrorValidator.evWrongInputType,
3290 TosaErrorValidator.evWrongOutputType,
3291 TosaErrorValidator.evWrongInputList,
3292 TosaErrorValidator.evWrongOutputList,
3293 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003294 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 "logical_or": {
3298 "op": Op.LOGICAL_OR,
3299 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003300 "build_fcn": (
3301 build_binary_broadcast,
3302 TosaTensorGen.tgBroadcastFuzz,
3303 TosaTensorValuesGen.tvgDefault,
3304 None,
3305 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003306 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003307 "error_if_validators": (
3308 TosaErrorValidator.evRankMismatch,
3309 TosaErrorValidator.evWrongInputType,
3310 TosaErrorValidator.evWrongOutputType,
3311 TosaErrorValidator.evWrongInputList,
3312 TosaErrorValidator.evWrongOutputList,
3313 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003314 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003315 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 "logical_xor": {
3318 "op": Op.LOGICAL_XOR,
3319 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 "build_fcn": (
3321 build_binary_broadcast,
3322 TosaTensorGen.tgBroadcastFuzz,
3323 TosaTensorValuesGen.tvgDefault,
3324 None,
3325 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003327 "error_if_validators": (
3328 TosaErrorValidator.evRankMismatch,
3329 TosaErrorValidator.evWrongInputType,
3330 TosaErrorValidator.evWrongOutputType,
3331 TosaErrorValidator.evWrongInputList,
3332 TosaErrorValidator.evWrongOutputList,
3333 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003334 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 "maximum": {
3338 "op": Op.MAXIMUM,
3339 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 "build_fcn": (
3341 build_binary_broadcast,
3342 TosaTensorGen.tgBroadcastFuzz,
3343 TosaTensorValuesGen.tvgDefault,
3344 None,
3345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003346 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 "error_if_validators": (
3348 TosaErrorValidator.evRankMismatch,
3349 TosaErrorValidator.evWrongInputType,
3350 TosaErrorValidator.evWrongOutputType,
3351 TosaErrorValidator.evWrongInputList,
3352 TosaErrorValidator.evWrongOutputList,
3353 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003354 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003357 "minimum": {
3358 "op": Op.MINIMUM,
3359 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003360 "build_fcn": (
3361 build_binary_broadcast,
3362 TosaTensorGen.tgBroadcastFuzz,
3363 TosaTensorValuesGen.tvgDefault,
3364 None,
3365 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003366 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 "error_if_validators": (
3368 TosaErrorValidator.evRankMismatch,
3369 TosaErrorValidator.evWrongInputType,
3370 TosaErrorValidator.evWrongOutputType,
3371 TosaErrorValidator.evWrongInputList,
3372 TosaErrorValidator.evWrongOutputList,
3373 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003374 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003375 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "mul": {
3378 "op": Op.MUL,
3379 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 "build_fcn": (
3381 build_mul,
3382 TosaTensorGen.tgBroadcastFuzz,
3383 TosaTensorValuesGen.tvgMul,
3384 TosaArgGen.agMul,
3385 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003386 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003387 "error_if_validators": (
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 TosaErrorValidator.evRankMismatch,
3393 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003394 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003397 "pow": {
3398 "op": Op.POW,
3399 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 "build_fcn": (
3401 build_binary_broadcast,
3402 TosaTensorGen.tgBroadcastFuzz,
3403 TosaTensorValuesGen.tvgDefault,
3404 None,
3405 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 "error_if_validators": (
3408 TosaErrorValidator.evRankMismatch,
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003414 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003415 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003417 "sub": {
3418 "op": Op.SUB,
3419 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003420 "build_fcn": (
3421 build_binary_broadcast,
3422 TosaTensorGen.tgBroadcastFuzz,
3423 TosaTensorValuesGen.tvgAddSub,
3424 None,
3425 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003427 "error_if_validators": (
3428 TosaErrorValidator.evRankMismatch,
3429 TosaErrorValidator.evWrongInputType,
3430 TosaErrorValidator.evWrongOutputType,
3431 TosaErrorValidator.evWrongInputList,
3432 TosaErrorValidator.evWrongOutputList,
3433 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003434 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003435 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003436 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003437 "table": {
3438 "op": Op.TABLE,
3439 # Use the automatic generation functions to create the input array
3440 # but create the table tensor in the build function, as it may be
3441 # a different type from the input
3442 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 "build_fcn": (
3444 build_table,
3445 TosaTensorGen.tgBasic,
3446 TosaTensorValuesGen.tvgDefault,
3447 TosaArgGen.agTable,
3448 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003449 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003450 "error_if_validators": (
3451 TosaErrorValidator.evWrongInputType,
3452 TosaErrorValidator.evWrongOutputType,
3453 TosaErrorValidator.evWrongInputList,
3454 TosaErrorValidator.evWrongOutputList,
3455 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 # Elementwise Unary operators
3458 "abs": {
3459 "op": Op.ABS,
3460 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 "build_fcn": (
3462 build_unary,
3463 TosaTensorGen.tgBasic,
3464 TosaTensorValuesGen.tvgDefault,
3465 None,
3466 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003467 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003468 "error_if_validators": (
3469 TosaErrorValidator.evWrongInputType,
3470 TosaErrorValidator.evWrongOutputType,
3471 TosaErrorValidator.evWrongInputList,
3472 TosaErrorValidator.evWrongOutputList,
3473 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 "bitwise_not": {
3476 "op": Op.BITWISE_NOT,
3477 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003478 "build_fcn": (
3479 build_unary,
3480 TosaTensorGen.tgBasic,
3481 TosaTensorValuesGen.tvgDefault,
3482 None,
3483 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003485 "error_if_validators": (
3486 TosaErrorValidator.evWrongInputType,
3487 TosaErrorValidator.evWrongOutputType,
3488 TosaErrorValidator.evWrongInputList,
3489 TosaErrorValidator.evWrongOutputList,
3490 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 "ceil": {
3493 "op": Op.CEIL,
3494 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003495 "build_fcn": (
3496 build_unary,
3497 TosaTensorGen.tgBasic,
3498 TosaTensorValuesGen.tvgDefault,
3499 None,
3500 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 "error_if_validators": (
3503 TosaErrorValidator.evWrongInputType,
3504 TosaErrorValidator.evWrongOutputType,
3505 TosaErrorValidator.evWrongInputList,
3506 TosaErrorValidator.evWrongOutputList,
3507 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003508 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003509 "clz": {
3510 "op": Op.CLZ,
3511 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 "build_fcn": (
3513 build_unary,
3514 TosaTensorGen.tgBasic,
3515 TosaTensorValuesGen.tvgDefault,
3516 None,
3517 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003518 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003519 "error_if_validators": (
3520 TosaErrorValidator.evWrongInputType,
3521 TosaErrorValidator.evWrongOutputType,
3522 TosaErrorValidator.evWrongInputList,
3523 TosaErrorValidator.evWrongOutputList,
3524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "exp": {
3527 "op": Op.EXP,
3528 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_unary,
3531 TosaTensorGen.tgBasic,
3532 TosaTensorValuesGen.tvgDefault,
3533 None,
3534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 "error_if_validators": (
3537 TosaErrorValidator.evWrongInputType,
3538 TosaErrorValidator.evWrongOutputType,
3539 TosaErrorValidator.evWrongInputList,
3540 TosaErrorValidator.evWrongOutputList,
3541 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003542 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 "floor": {
3544 "op": Op.FLOOR,
3545 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 "build_fcn": (
3547 build_unary,
3548 TosaTensorGen.tgBasic,
3549 TosaTensorValuesGen.tvgDefault,
3550 None,
3551 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003552 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 "error_if_validators": (
3554 TosaErrorValidator.evWrongInputType,
3555 TosaErrorValidator.evWrongOutputType,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003560 "log": {
3561 "op": Op.LOG,
3562 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 "build_fcn": (
3564 build_unary,
3565 TosaTensorGen.tgBasic,
3566 TosaTensorValuesGen.tvgDefault,
3567 None,
3568 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003569 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003570 "error_if_validators": (
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003576 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003577 "logical_not": {
3578 "op": Op.LOGICAL_NOT,
3579 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003580 "build_fcn": (
3581 build_unary,
3582 TosaTensorGen.tgBasic,
3583 TosaTensorValuesGen.tvgDefault,
3584 None,
3585 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003586 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003587 "error_if_validators": (
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongInputList,
3591 TosaErrorValidator.evWrongOutputList,
3592 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 "negate": {
3595 "op": Op.NEGATE,
3596 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003597 "build_fcn": (
3598 build_unary,
3599 TosaTensorGen.tgBasic,
3600 TosaTensorValuesGen.tvgNegate,
3601 None,
3602 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 "qgen": TosaQuantGen.qgUnary,
3604 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evInputZeroPointNotZero,
3607 TosaErrorValidator.evOutputZeroPointNotZero,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongInputList,
3611 TosaErrorValidator.evWrongOutputList,
3612 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003613 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 "reciprocal": {
3615 "op": Op.RECIPROCAL,
3616 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617 "build_fcn": (
3618 build_unary,
3619 TosaTensorGen.tgBasic,
3620 TosaTensorValuesGen.tvgDefault,
3621 None,
3622 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003623 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 "error_if_validators": (
3625 TosaErrorValidator.evWrongInputType,
3626 TosaErrorValidator.evWrongOutputType,
3627 TosaErrorValidator.evWrongInputList,
3628 TosaErrorValidator.evWrongOutputList,
3629 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003631 "rsqrt": {
3632 "op": Op.RSQRT,
3633 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 "build_fcn": (
3635 build_unary,
3636 TosaTensorGen.tgBasic,
3637 TosaTensorValuesGen.tvgDefault,
3638 None,
3639 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003640 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 "error_if_validators": (
3642 TosaErrorValidator.evWrongInputType,
3643 TosaErrorValidator.evWrongOutputType,
3644 TosaErrorValidator.evWrongInputList,
3645 TosaErrorValidator.evWrongOutputList,
3646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 # Elementwise Ternary operators
3649 "select": {
3650 "op": Op.SELECT,
3651 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 "build_fcn": (
3653 build_select,
3654 TosaTensorGen.tgBroadcastFuzz,
3655 TosaTensorValuesGen.tvgSelect,
3656 None,
3657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evRankMismatch,
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
3665 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003666 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 # Comparison operators
3670 "equal": {
3671 "op": Op.EQUAL,
3672 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
3674 build_comparison,
3675 TosaTensorGen.tgBroadcastFuzz,
3676 TosaTensorValuesGen.tvgEqual,
3677 None,
3678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evRankMismatch,
3682 TosaErrorValidator.evWrongInputType,
3683 TosaErrorValidator.evWrongOutputType,
3684 TosaErrorValidator.evWrongInputList,
3685 TosaErrorValidator.evWrongOutputList,
3686 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003687 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003690 "greater_equal": {
3691 "op": Op.GREATER_EQUAL,
3692 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003693 "build_fcn": (
3694 build_comparison,
3695 TosaTensorGen.tgBroadcastFuzz,
3696 TosaTensorValuesGen.tvgDefault,
3697 None,
3698 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003699 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003700 "error_if_validators": (
3701 TosaErrorValidator.evRankMismatch,
3702 TosaErrorValidator.evWrongInputType,
3703 TosaErrorValidator.evWrongOutputType,
3704 TosaErrorValidator.evWrongInputList,
3705 TosaErrorValidator.evWrongOutputList,
3706 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003707 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003708 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "greater": {
3711 "op": Op.GREATER,
3712 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003713 "build_fcn": (
3714 build_comparison,
3715 TosaTensorGen.tgBroadcastFuzz,
3716 TosaTensorValuesGen.tvgDefault,
3717 None,
3718 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003719 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003720 "error_if_validators": (
3721 TosaErrorValidator.evRankMismatch,
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongInputList,
3725 TosaErrorValidator.evWrongOutputList,
3726 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003727 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003729 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 # Reduction operators
3731 "reduce_all": {
3732 "op": Op.REDUCE_ALL,
3733 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003734 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_reduce,
3737 TosaTensorGen.tgBasic,
3738 TosaTensorValuesGen.tvgDefault,
3739 TosaArgGen.agAxis,
3740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evAxisLargerRank,
3744 TosaErrorValidator.evAxisSmallerZero,
3745 TosaErrorValidator.evShapeOfAxisNotOne,
3746 TosaErrorValidator.evWrongInputType,
3747 TosaErrorValidator.evWrongOutputType,
3748 TosaErrorValidator.evWrongRank,
3749 TosaErrorValidator.evWrongInputList,
3750 TosaErrorValidator.evWrongOutputList,
3751 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "reduce_any": {
3754 "op": Op.REDUCE_ANY,
3755 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003756 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 "build_fcn": (
3758 build_reduce,
3759 TosaTensorGen.tgBasic,
3760 TosaTensorValuesGen.tvgDefault,
3761 TosaArgGen.agAxis,
3762 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003763 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003764 "error_if_validators": (
3765 TosaErrorValidator.evAxisLargerRank,
3766 TosaErrorValidator.evAxisSmallerZero,
3767 TosaErrorValidator.evShapeOfAxisNotOne,
3768 TosaErrorValidator.evWrongInputType,
3769 TosaErrorValidator.evWrongOutputType,
3770 TosaErrorValidator.evWrongRank,
3771 TosaErrorValidator.evWrongInputList,
3772 TosaErrorValidator.evWrongOutputList,
3773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003774 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "reduce_max": {
3776 "op": Op.REDUCE_MAX,
3777 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003778 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003779 "build_fcn": (
3780 build_reduce,
3781 TosaTensorGen.tgBasic,
3782 TosaTensorValuesGen.tvgDefault,
3783 TosaArgGen.agAxis,
3784 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003786 "error_if_validators": (
3787 TosaErrorValidator.evAxisLargerRank,
3788 TosaErrorValidator.evAxisSmallerZero,
3789 TosaErrorValidator.evShapeOfAxisNotOne,
3790 TosaErrorValidator.evWrongInputType,
3791 TosaErrorValidator.evWrongOutputType,
3792 TosaErrorValidator.evWrongRank,
3793 TosaErrorValidator.evWrongInputList,
3794 TosaErrorValidator.evWrongOutputList,
3795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003798 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003800 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 "build_fcn": (
3802 build_reduce,
3803 TosaTensorGen.tgBasic,
3804 TosaTensorValuesGen.tvgDefault,
3805 TosaArgGen.agAxis,
3806 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 "error_if_validators": (
3809 TosaErrorValidator.evAxisLargerRank,
3810 TosaErrorValidator.evAxisSmallerZero,
3811 TosaErrorValidator.evShapeOfAxisNotOne,
3812 TosaErrorValidator.evWrongInputType,
3813 TosaErrorValidator.evWrongOutputType,
3814 TosaErrorValidator.evWrongRank,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "reduce_product": {
3820 "op": Op.REDUCE_PRODUCT,
3821 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003822 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003823 "build_fcn": (
3824 build_reduce,
3825 TosaTensorGen.tgBasic,
3826 TosaTensorValuesGen.tvgDefault,
3827 TosaArgGen.agAxis,
3828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003830 "error_if_validators": (
3831 TosaErrorValidator.evAxisLargerRank,
3832 TosaErrorValidator.evAxisSmallerZero,
3833 TosaErrorValidator.evShapeOfAxisNotOne,
3834 TosaErrorValidator.evWrongInputType,
3835 TosaErrorValidator.evWrongOutputType,
3836 TosaErrorValidator.evWrongRank,
3837 TosaErrorValidator.evWrongInputList,
3838 TosaErrorValidator.evWrongOutputList,
3839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "reduce_sum": {
3842 "op": Op.REDUCE_SUM,
3843 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003844 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003845 "build_fcn": (
3846 build_reduce,
3847 TosaTensorGen.tgBasic,
3848 TosaTensorValuesGen.tvgReduceSum,
3849 TosaArgGen.agAxis,
3850 ),
James Ward24dbc422022-10-19 12:20:31 +01003851 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003852 "error_if_validators": (
3853 TosaErrorValidator.evAxisLargerRank,
3854 TosaErrorValidator.evAxisSmallerZero,
3855 TosaErrorValidator.evShapeOfAxisNotOne,
3856 TosaErrorValidator.evWrongInputType,
3857 TosaErrorValidator.evWrongOutputType,
3858 TosaErrorValidator.evWrongRank,
3859 TosaErrorValidator.evWrongInputList,
3860 TosaErrorValidator.evWrongOutputList,
3861 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003863 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003864 "concat": {
3865 "op": Op.CONCAT,
3866 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 "build_fcn": (
3868 build_concat,
3869 TosaTensorGen.tgConcat,
3870 TosaTensorValuesGen.tvgConcat,
3871 TosaArgGen.agAxis,
3872 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003873 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 "error_if_validators": (
3875 TosaErrorValidator.evAxisLargerRank,
3876 TosaErrorValidator.evAxisSmallerZero,
3877 TosaErrorValidator.evConcatInputRankMismatch,
3878 TosaErrorValidator.evConcatShapeSumMismatch,
3879 TosaErrorValidator.evConcatInputDimMismatch,
3880 TosaErrorValidator.evWrongInputType,
3881 TosaErrorValidator.evWrongOutputType,
3882 TosaErrorValidator.evWrongOutputList,
3883 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003884 },
3885 "pad": {
3886 "op": Op.PAD,
3887 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003888 "build_fcn": (
3889 build_pad,
3890 TosaTensorGen.tgBasic,
3891 TosaTensorValuesGen.tvgDefault,
3892 TosaArgGen.agPad,
3893 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003894 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 "error_if_validators": (
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003898 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 TosaErrorValidator.evWrongOutputType,
3900 TosaErrorValidator.evWrongInputList,
3901 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003902 TosaErrorValidator.evRankMismatch,
3903 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003905 },
Won Jeona21b2e82023-08-10 10:33:01 +00003906 "dim": {
3907 "op": Op.DIM,
3908 "operands": (1, 0),
3909 "build_fcn": (
3910 build_dim,
3911 TosaTensorGen.tgBasic,
3912 TosaTensorValuesGen.tvgDefault,
3913 TosaArgGen.agAxis,
3914 ),
3915 "types": TYPE_FIB,
3916 "error_if_validators": (
3917 TosaErrorValidator.evAxisLargerRank,
3918 TosaErrorValidator.evAxisSmallerZero,
3919 TosaErrorValidator.evWrongInputType,
3920 TosaErrorValidator.evWrongInputList,
3921 TosaErrorValidator.evWrongOutputList,
3922 TosaErrorValidator.evWrongRank,
3923 ),
3924 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003925 "reshape": {
3926 "op": Op.RESHAPE,
3927 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003928 "build_fcn": (
3929 build_reshape,
3930 TosaTensorGen.tgBasic,
3931 TosaTensorValuesGen.tvgDefault,
3932 TosaArgGen.agReshape,
3933 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003934 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003935 "error_if_validators": (
3936 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3937 TosaErrorValidator.evWrongInputType,
3938 TosaErrorValidator.evWrongOutputType,
3939 TosaErrorValidator.evWrongInputList,
3940 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003941 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3942 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003943 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003944 },
3945 "reverse": {
3946 "op": Op.REVERSE,
3947 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003948 "build_fcn": (
3949 build_reverse,
3950 TosaTensorGen.tgBasic,
3951 TosaTensorValuesGen.tvgDefault,
3952 TosaArgGen.agAxis,
3953 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003954 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003955 "error_if_validators": (
3956 TosaErrorValidator.evAxisSmallerZero,
3957 TosaErrorValidator.evAxisLargerRank,
3958 TosaErrorValidator.evWrongInputType,
3959 TosaErrorValidator.evWrongOutputType,
3960 TosaErrorValidator.evWrongInputList,
3961 TosaErrorValidator.evWrongOutputList,
3962 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003963 },
3964 "slice": {
3965 "op": Op.SLICE,
3966 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003967 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003968 "build_fcn": (
3969 build_slice,
3970 TosaTensorGen.tgBasic,
3971 TosaTensorValuesGen.tvgDefault,
3972 TosaArgGen.agSlice,
3973 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003974 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003975 "error_if_validators": (
3976 TosaErrorValidator.evStartSmallerZero,
3977 TosaErrorValidator.evSizeSmallerEqualZero,
3978 TosaErrorValidator.evStartSizeOutsideBounds,
3979 TosaErrorValidator.evSizeOutputShapeMismatch,
3980 TosaErrorValidator.evInputSizeStartLengthMismatch,
3981 TosaErrorValidator.evWrongRank,
3982 TosaErrorValidator.evWrongInputType,
3983 TosaErrorValidator.evWrongOutputType,
3984 TosaErrorValidator.evWrongInputList,
3985 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003986 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003988 },
3989 "tile": {
3990 "op": Op.TILE,
3991 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003992 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 "build_fcn": (
3994 build_tile,
3995 TosaTensorGen.tgBasic,
3996 TosaTensorValuesGen.tvgDefault,
3997 TosaArgGen.agTile,
3998 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003999 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004000 "error_if_validators": (
4001 TosaErrorValidator.evWrongInputType,
4002 TosaErrorValidator.evWrongOutputType,
4003 TosaErrorValidator.evWrongInputList,
4004 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004005 TosaErrorValidator.evRankMismatch,
4006 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004007 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004008 },
4009 "transpose": {
4010 "op": Op.TRANSPOSE,
4011 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004012 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004013 "build_fcn": (
4014 build_transpose,
4015 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004017 TosaArgGen.agTranspose,
4018 ),
4019 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004020 "error_if_validators": (
4021 TosaErrorValidator.evIndexOutsideBounds,
4022 TosaErrorValidator.evIndexUsedTwice,
4023 TosaErrorValidator.evWrongInputType,
4024 TosaErrorValidator.evWrongOutputType,
4025 TosaErrorValidator.evWrongInputList,
4026 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004027 TosaErrorValidator.evWrongRank,
4028 TosaErrorValidator.evRankMismatch,
4029 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004030 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 # Data nodes
4033 "const": {
4034 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004035 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 "build_fcn": (
4037 build_const,
4038 TosaTensorGen.tgBasic,
4039 TosaTensorValuesGen.tvgDefault,
4040 None,
4041 ),
Luke Hutton65872422023-02-20 10:33:04 +00004042 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08004043 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004044 "identity": {
4045 "op": Op.IDENTITY,
4046 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004047 "build_fcn": (
4048 build_unary,
4049 TosaTensorGen.tgBasic,
4050 TosaTensorValuesGen.tvgDefault,
4051 None,
4052 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004053 "types": TYPE_FIB,
4054 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004055 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004056 "gather": {
4057 "op": Op.GATHER,
4058 # Only specify 'values' tensor here. 'indices' is generated in op building stage
4059 "operands": (1, 0),
4060 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004061 "build_fcn": (
4062 build_gather,
4063 TosaTensorGen.tgBasic,
4064 TosaTensorValuesGen.tvgDefault,
4065 None,
4066 ),
James Ward24dbc422022-10-19 12:20:31 +01004067 "types": (
4068 DType.INT8,
4069 DType.INT16,
4070 DType.INT32,
4071 DType.FP16,
4072 DType.BF16,
4073 DType.FP32,
4074 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004075 "error_if_validators": (
4076 TosaErrorValidator.evWrongInputType,
4077 TosaErrorValidator.evWrongOutputType,
4078 TosaErrorValidator.evWrongInputList,
4079 TosaErrorValidator.evWrongOutputList,
4080 TosaErrorValidator.evWrongRank,
4081 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004082 },
4083 "scatter": {
4084 "op": Op.SCATTER,
4085 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004086 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08004087 "operands": (2, 0),
4088 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004089 "build_fcn": (
4090 build_scatter,
4091 TosaTensorGen.tgScatter,
4092 TosaTensorValuesGen.tvgDefault,
4093 None,
4094 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004096 "error_if_validators": (
4097 TosaErrorValidator.evWrongInputType,
4098 TosaErrorValidator.evWrongOutputType,
4099 TosaErrorValidator.evWrongInputList,
4100 TosaErrorValidator.evWrongOutputList,
4101 TosaErrorValidator.evWrongRank,
4102 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004104 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004105 "resize": {
4106 "op": Op.RESIZE,
4107 "operands": (1, 0),
4108 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004109 "build_fcn": (
4110 build_resize,
4111 TosaTensorGen.tgNHWC,
4112 TosaTensorValuesGen.tvgDefault,
4113 TosaArgGen.agResize,
4114 ),
James Ward24dbc422022-10-19 12:20:31 +01004115 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004116 "invalid_test_validators": (
4117 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 ),
4119 "error_if_validators": (
4120 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004121 TosaErrorValidator.evScaleSmallerEqualZero,
4122 TosaErrorValidator.evScaleNLargerMax,
4123 TosaErrorValidator.evScaleDLargerMax,
4124 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004125 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004126 TosaErrorValidator.evBorderSmallerMin,
4127 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 TosaErrorValidator.evWrongInputType,
4129 TosaErrorValidator.evWrongOutputType,
4130 TosaErrorValidator.evWrongRank,
4131 TosaErrorValidator.evWrongInputList,
4132 TosaErrorValidator.evWrongOutputList,
4133 TosaErrorValidator.evBatchMismatch,
4134 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004135 TosaErrorValidator.evResizeOutputShapeMismatch,
4136 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004137 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004138 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004139 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004140 "cast": {
4141 "op": Op.CAST,
4142 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004143 "build_fcn": (
4144 build_cast,
4145 TosaTensorGen.tgBasic,
4146 TosaTensorValuesGen.tvgDefault,
4147 TosaArgGen.agCast,
4148 ),
James Ward8b390432022-08-12 20:48:56 +01004149 "types": (
4150 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004151 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004152 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004153 DType.INT8,
4154 DType.INT16,
4155 DType.INT32,
4156 DType.BOOL,
4157 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004158 "error_if_validators": (
4159 TosaErrorValidator.evWrongInputType,
4160 TosaErrorValidator.evWrongOutputType,
4161 TosaErrorValidator.evWrongInputList,
4162 TosaErrorValidator.evWrongOutputList,
4163 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004164 },
4165 "rescale": {
4166 "op": Op.RESCALE,
4167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004168 "build_fcn": (
4169 build_rescale,
4170 TosaTensorGen.tgBasic,
4171 TosaTensorValuesGen.tvgDefault,
4172 TosaArgGen.agRescale,
4173 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004174 "types": [
4175 DType.UINT8,
4176 DType.INT8,
4177 DType.INT16,
4178 DType.INT32,
4179 DType.INT48,
4180 DType.UINT16,
4181 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004182 "error_if_validators": (
4183 TosaErrorValidator.evInputZeroPointNotZero,
4184 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004185 TosaErrorValidator.evU16InputZeroPointNotValid,
4186 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004187 TosaErrorValidator.evScaleTrue,
4188 TosaErrorValidator.evScaleNotTrue,
4189 TosaErrorValidator.evWrongInputType,
4190 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004191 TosaErrorValidator.evWrongInputList,
4192 TosaErrorValidator.evWrongOutputList,
4193 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004194 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004195 # Custom
4196 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004197 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004198 # Two variants of cond_if, one that generates one of two constant tensors (no
4199 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4200 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004201 "cond_if_const": {
4202 "op": Op.COND_IF,
4203 "operands": (0, 2),
4204 "build_fcn": (
4205 build_cond_if_const,
4206 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004207 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004208 TosaArgGen.agCondIf,
4209 ),
4210 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004211 "error_if_validators": (
4212 TosaErrorValidator.evOutputListThenGraphMismatch,
4213 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004214 TosaErrorValidator.evCondIfCondNotMatchingBool,
4215 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004216 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004217 },
4218 "cond_if_binary": {
4219 "op": Op.COND_IF,
4220 "operands": (2, 0),
4221 "build_fcn": (
4222 build_cond_if_binary,
4223 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004224 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004225 TosaArgGen.agCondIf,
4226 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004227 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004228 "error_if_validators": (
4229 TosaErrorValidator.evInputListThenGraphMismatch,
4230 TosaErrorValidator.evInputListElseGraphMismatch,
4231 TosaErrorValidator.evOutputListThenGraphMismatch,
4232 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004233 TosaErrorValidator.evCondIfCondNotMatchingBool,
4234 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004235 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004236 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004237 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004238 "while_loop": {
4239 "op": Op.WHILE_LOOP,
4240 "operands": (0, 1),
4241 "build_fcn": (
4242 build_while_loop,
4243 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004245 TosaArgGen.agWhileLoop,
4246 ),
4247 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evInputListOutputListMismatch,
4250 TosaErrorValidator.evInputListCondGraphMismatch,
4251 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4252 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4253 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004254 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004255 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004256 },
Luke Hutton57287132023-02-06 14:54:18 +00004257 "fft2d": {
4258 "op": Op.FFT2D,
4259 "operands": (2, 0),
4260 "rank": (3, 3),
4261 "build_fcn": (
4262 build_fft2d,
4263 TosaTensorGen.tgFFT2d,
4264 TosaTensorValuesGen.tvgDefault,
4265 TosaArgGen.agFFT2d,
4266 ),
4267 "types": [DType.FP32],
4268 "error_if_validators": (
4269 TosaErrorValidator.evWrongInputType,
4270 TosaErrorValidator.evWrongOutputType,
4271 TosaErrorValidator.evWrongInputList,
4272 TosaErrorValidator.evWrongOutputList,
4273 TosaErrorValidator.evWrongRank,
4274 TosaErrorValidator.evBatchMismatch,
4275 TosaErrorValidator.evKernelNotPowerOfTwo,
4276 TosaErrorValidator.evFFTInputShapeMismatch,
4277 TosaErrorValidator.evFFTOutputShapeMismatch,
4278 ),
4279 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004280 "rfft2d": {
4281 "op": Op.RFFT2D,
4282 "operands": (1, 0),
4283 "rank": (3, 3),
4284 "build_fcn": (
4285 build_rfft2d,
4286 TosaTensorGen.tgRFFT2d,
4287 TosaTensorValuesGen.tvgDefault,
4288 TosaArgGen.agNone,
4289 ),
4290 "types": [DType.FP32],
4291 "error_if_validators": (
4292 TosaErrorValidator.evWrongInputType,
4293 TosaErrorValidator.evWrongOutputType,
4294 TosaErrorValidator.evWrongInputList,
4295 TosaErrorValidator.evWrongOutputList,
4296 TosaErrorValidator.evWrongRank,
4297 TosaErrorValidator.evBatchMismatch,
4298 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004299 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004300 ),
4301 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004302 }
4303
Kevin Cheng550ccc52021-03-03 11:21:43 -08004304
Eric Kunzee5e26762020-10-13 16:11:07 -07004305class OutputShaper:
4306 # Methods in this class compute the expected output shape and datatype
4307 # for common classes of operations
def __init__(self):
    """Create an OutputShaper.

    The constructor stores nothing on the instance; the shape helpers that
    follow are static methods and are called directly on the class.
    """
4310
4311 # These methods return arguments that can be used for
4312 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise binary op with broadcasting.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used to fuzz the ERROR_IF variants.
        a, b: Input tensors. They must share a dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, a rank.
        error_name: ``None`` for a valid test, otherwise the ``ErrorIf``
            case being generated.

    Returns:
        The value returned by ``ser.addOutput`` for the computed shape/dtype.

    For valid tests a size-1 dimension of ``a`` takes the size of the
    matching dimension of ``b``; ERROR_IF tests keep ``a``'s dimensions.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        # Broadcasting is only applied for valid tests; under RankMismatch
        # the rank check above is skipped, so b may lack dimension i.
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn whether or not it is used, which keeps rng
    # consumption independent of error_name (presumably deliberate, for
    # reproducible generation - keep it). A rank-0 tensor has no dimension
    # to pick and rng.integers(0, 0) raises ValueError, so skip the draw then.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004346
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for an element-wise binary op without broadcast.

    ``a`` and ``b`` must agree exactly in rank, dtype and every dimension
    (enforced with ``assert``). The output copies ``a``'s shape and dtype.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
    out_shape = list(a.shape)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004358
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an element-wise unary op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``, in which case a different
    dtype is drawn at random from a fixed candidate list.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004377
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used to fuzz the ERROR_IF variants.
        cond: Condition tensor; its shape drives the output shape.
        a, b: Value tensors. They must share a dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, the rank of ``cond``.
        error_name: ``None`` for a valid test, otherwise the ``ErrorIf``
            case being generated.

    Returns:
        The value returned by ``ser.addOutput``.

    For valid tests a size-1 dimension of ``cond`` becomes the largest of
    the matching ``cond``/``a``/``b`` dimensions; otherwise ``cond``'s
    dimension is used. The output dtype is ``a.dtype``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # The fuzz index is drawn whether or not it is used, which keeps rng
    # consumption independent of error_name (presumably deliberate - keep
    # it). A rank-0 tensor has no dimension to pick and rng.integers(0, 0)
    # raises ValueError, so skip the draw in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004411
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an element-wise comparison op.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used to fuzz the ERROR_IF variants.
        a, b: Input tensors. They must share a dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, a rank.
        error_name: ``None`` for a valid test, otherwise the ``ErrorIf``
            case being generated.

    Returns:
        The value returned by ``ser.addOutput``.

    A size-1 dimension of ``a`` takes the matching dimension of ``b`` when
    ``b`` has one. The output dtype is BOOL except for
    ``ErrorIf.WrongOutputType``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn whether or not it is used, which keeps rng
    # consumption independent of error_name (presumably deliberate - keep
    # it). A rank-0 tensor has no dimension to pick and rng.integers(0, 0)
    # raises ValueError, so skip the draw in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is the valid output type and is absent from this list.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004445
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    Normally the reduced axis becomes size 1. For the invalid-axis ERROR_IF
    cases the shape is left as-is. For ``ShapeOfAxisNotOne`` the axis is
    left unchanged, or set to a random size in [2, 10) if it was already 1.
    The dtype matches ``a`` unless ``WrongOutputType`` is requested.
    """
    out_shape = a.shape.copy()
    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name in axis_errors:
        if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    else:
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004474
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX.

    The output is ``a``'s shape with ``axis`` removed (kept intact for the
    invalid-axis ERROR_IF cases) and dtype INT32.

    ERROR_IF variants:
      * ``ArgmaxOutputRankMismatch``: randomly drop the leading dimension
        (only if at least one remains) or append a trailing 1.
      * ``ArgmaxOutputShapeMismatch``: grow every dimension by a random
        amount in [1, 10).
      * ``WrongOutputType``: any listed dtype except INT32.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.INT32
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004508
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    The output dtype is ``accum_dtype``. ``ErrorIf.WrongInputType`` uses
    INT32 instead. ``ErrorIf.WrongOutputType`` draws from
    ``gtu.usableDTypes`` minus an exclusion list.
    ``ErrorIf.ConvOutputShapeMismatch`` grows H and/or W by a random
    multiple of the stride.
    """

    def _spatial(size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent along one spatial axis.
        span = size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    out_h = _spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the non-integer output size
        # error case is not hit as well.
        steps = [1, 2, 3]
        target = rng.choice(steps)  # 1: H only, 2: W only, 3: both
        if target in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if target in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With a wrong input type, INT32 is just a potentially correct guess.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # can be a valid accumulator type for FP16).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004560
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial size
    is ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    The output dtype is ``accum_dtype``. ``ErrorIf.WrongInputType`` uses
    INT32 instead. ``ErrorIf.WrongOutputType`` draws from
    ``gtu.usableDTypes`` minus an exclusion list.
    ``ErrorIf.ConvOutputShapeMismatch`` grows D, H, W or all three by a
    random multiple of the stride.
    """

    def _spatial(size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent along one spatial axis.
        span = size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    out_d = _spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _spatial(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the non-integer output size
        # error case is not hit as well.
        steps = [1, 2, 3, 4]
        target = rng.choice(steps)  # 1: D, 2: H, 3: W, 4: all three
        if target in (1, 4):
            out_d += rng.choice(steps) * strides[0]
        if target in (2, 4):
            out_h += rng.choice(steps) * strides[1]
        if target in (3, 4):
            out_w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With a wrong input type, INT32 is just a potentially correct guess.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # can be a valid accumulator type for FP16).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4622
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial
    size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    The output dtype is ``accum_dtype``. ``ErrorIf.WrongInputType`` uses
    INT32 instead. ``ErrorIf.WrongOutputType`` draws from
    ``gtu.usableDTypes`` minus an exclusion list.
    ``ErrorIf.ConvOutputShapeMismatch`` grows H and/or W by a random
    multiple of the stride.
    """

    def _spatial(size, pad_before, pad_after, kernel, dilation, stride):
        # Output extent along one spatial axis.
        span = size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    out_h = _spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the non-integer output size
        # error case is not hit as well.
        steps = [1, 2, 3]
        target = rng.choice(steps)  # 1: H only, 2: W only, 3: both
        if target in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if target in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    # With a wrong input type, INT32 is just a potentially correct guess.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably because either
        # can be a valid accumulator type for FP16).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004673
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling op (NHWC in, NHWC out).

    With positive strides and non-negative padding each spatial size is
    ``(in + pad_before + pad_after - kernel) // stride + 1``. Otherwise the
    test is invalid anyway, and H and W are simply set to 1.
    ``ErrorIf.PoolingOutputShapeMismatch`` grows H and/or W by a random
    multiple of the stride. ``ErrorIf.WrongOutputType`` picks a dtype
    other than ``ifm.dtype``.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow by multiples of the stride so the non-integer output size
        # error case is not hit as well.
        steps = [1, 2, 3]
        target = rng.choice(steps)  # 1: H only, 2: W only, 3: both
        if target in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if target in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004711
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: ``input`` is [N, IC], ``filter`` is [OC, IC], output is [N, OC].
    The output dtype is always ``accum_dtype``. It is validated (and
    invalidated for ERROR_IF tests) by the argument generator, so ``rng``
    and ``error_name`` are not consulted here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004724
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W], output is [N, H, W].

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator used for the ERROR_IF variants.
        a, b: Input tensors.
        accum_dtype: Output dtype for valid tests (validated in arg_gen).
        error_name: ``None`` for a valid test, otherwise the ``ErrorIf``
            case being generated.

    Returns:
        The value returned by ``ser.addOutput``.

    ERROR_IF variants:
      * ``WrongOutputType``: pick at random from a per-input-dtype list of
        incorrect types. Input dtypes without a dedicated list fall back to
        any ``gtu.usableDTypes`` entry other than ``accum_dtype``.
      * ``WrongInputType``: use INT32 as a potentially correct output type.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is deliberately absent.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # INT48 is deliberately absent.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Previously this case left incorrect_types unbound and raised
            # UnboundLocalError; choose any usable type other than the
            # expected output type instead.
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004772
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for CONCAT of the tensors in ``a``.

    The output starts as a copy of the first input's shape. Unless the
    ERROR_IF case makes it impossible (rank mismatch or an invalid axis),
    the ``axis`` sizes of the remaining inputs are added to it.
    ``ErrorIf.ConcatShapeSumMismatch`` then adds a random 5..9 to that axis.
    The dtype is the first input's, except for ``WrongOutputType``.
    """
    first = a[0]
    others = a[1:]

    out_shape = first.shape.copy()
    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004808
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Output dimension i is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ``ErrorIf.PadOutputShapeMismatch`` shrinks one random dimension by 1 or
    2. ``ErrorIf.RankMismatch`` replaces the shape via
    ``gtu.get_rank_mismatch_shape``. The dtype matches ``a`` except for
    ``WrongOutputType``.
    """
    out_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004839
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for DIM.

    The output is created with an empty shape list and dtype
    ``DType.SHAPE``; ``a`` and ``axis`` do not influence it. For
    ``ErrorIf.WrongOutputType`` a non-SHAPE dtype is drawn at random.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.SHAPE
    else:
        non_shape_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape_dtypes)))

    return ser.addOutput([], out_dtype)
4860
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE.

    The output shape is a copy of the requested ``shape``. For
    ``ErrorIf.TensorSizeInputOutputMismatch`` each dimension is grown by a
    random amount in [1, 10). For ``ErrorIf.WrongOutputType`` a dtype other
    than ``a.dtype`` is chosen.
    """
    out_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004885
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE.

    The output shape is normally a copy of ``size`` (``start`` is unused).

    ERROR_IF variants:
      * ``SizeOutputShapeMismatch``: every dimension is nudged randomly;
        dimensions of size <= 2 are only ever increased.
      * ``InputSizeStartLengthMismatch``: the input's shape is used instead.
      * ``RankMismatch``: the rank is altered via
        ``gtu.get_rank_mismatch_shape``.
      * ``WrongOutputType``: a dtype other than ``input.dtype``.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        nudged = []
        for dim in out_shape:
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            nudged.append(dim + rng.choice(deltas))
        out_shape = nudged
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004919
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE.

    Output dimension i is ``a.shape[i] * multiples[i]``; ``multiples`` must
    have one entry per dimension (asserted). ``ErrorIf.RankMismatch``
    replaces the shape via ``gtu.get_rank_mismatch_shape``. The dtype
    matches ``a`` except for ``WrongOutputType``.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004948
4949 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a transpose operation.

    Output dimension ``i`` is input dimension ``perms[i]``. ERROR_IF cases
    selected by ``error_name``:

    * ``ErrorIf.IndexOutsideBounds`` / ``ErrorIf.IndexUsedTwice`` - the
      permutation is not applied, so the input shape is kept.
    * ``ErrorIf.TensorSizeInputOutputMismatch`` - each dimension is enlarged
      by a value drawn with ``rng.integers(1, 10)``.
    * ``ErrorIf.RankMismatch`` - the shape is replaced by
      ``gtu.get_rank_mismatch_shape``.
    * ``ErrorIf.WrongOutputType`` - a random dtype different from
      ``a.dtype`` is used.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator (``choice`` / ``integers`` are called on it).
        a: Input tensor; its ``shape`` and ``dtype`` are read.
        perms: Permutation, one entry per input dimension.
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    shape_out = a.shape.copy()
    assert len(perms) == len(shape_out)

    # An invalid permutation cannot be applied to the shape.
    invalid_perms = error_name in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if not invalid_perms:
        for dst in range(len(shape_out)):
            shape_out[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst in range(len(shape_out)):
            shape_out[dst] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        shape_out = gtu.get_rank_mismatch_shape(rng, shape_out)

    dtype_out = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        dtype_out = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(shape_out, dtype_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07004981
4982 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a gather operation.

    The output shape is ``[values.shape[0], indices.shape[1],
    values.shape[2]]``. The dtype is ``values.dtype``, unless ``error_name``
    is ``ErrorIf.WrongOutputType``. In that case a random dtype different
    from ``values.dtype`` is used. The rank/batch assertions are skipped for
    ``ErrorIf.WrongRank``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator (``choice`` is called on it for error cases).
        values: Values tensor (expected 3-D); ``shape``/``dtype`` are read.
        indices: Indices tensor (expected 2-D); ``shape`` is read.
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    shape_out = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        dtype_out = values.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        dtype_out = rng.choice(list(set(candidates) - {values.dtype}))

    return ser.addOutput(shape_out, dtype_out)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005007
5008 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a scatter operation.

    The output takes ``values_in.shape`` (the same object is passed through)
    and ``values_in.dtype``. If ``error_name`` is ``ErrorIf.WrongOutputType``,
    a random dtype different from ``values_in.dtype`` is used instead. Rank
    and N/W/C consistency checks are skipped for ``ErrorIf.WrongRank``.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator (``choice`` is called on it for error cases).
        values_in: Values tensor (expected 3-D); ``shape``/``dtype`` read.
        indices: Indices tensor (expected 2-D); ``shape`` is read.
        input: Input tensor (expected 3-D); ``shape`` is read.
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    shape_out = values_in.shape

    dtype_out = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        dtype_out = rng.choice(list(set(candidates) - {values_in.dtype}))

    return ser.addOutput(shape_out, dtype_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07005036
5037 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a table operation.

    The output keeps ``input.shape``. Its dtype depends on the input dtype:
    ``DType.INT16`` input gives ``DType.INT32``, anything else gives
    ``DType.INT8``. For ``ErrorIf.WrongOutputType`` a random dtype other than
    that expected one is chosen instead.

    Args:
        ser: Serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: Random generator (``choice`` is called on it for error cases).
        input: Input tensor; ``shape`` and ``dtype`` are read.
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        # Only INT16 and INT8 inputs are accepted outside that error case.
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(
            [dtype for dtype in all_dtypes if dtype != output_dtype]
        )

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005056
5057 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a resize operation.

    The output height/width come from input dimensions 1 and 2:

        OH = ((IH - 1) * scale[0] - offset[0] + border[0]) // scale[1] + 1
        OW = ((IW - 1) * scale[2] - offset[1] + border[1]) // scale[3] + 1

    When ``error_name`` is not ``None``, OH/OW are first clamped to at least
    1. Except for ``ErrorIf.MaxDimExceeded``, they are also clamped to at most
    ``gtu.MAX_RESIZE_DIMENSION - 1``. Specific ERROR_IF cases then corrupt
    the result:

    * ``ScaleSmallerEqualZero`` - scale terms below 1 are raised to 1 for the
      size arithmetic only.
    * ``ResizeOutputShapeMismatch`` - OH and/or OW move by one whole scale
      denominator.
    * ``WrongRank`` / ``BatchMismatch`` / ``ChannelMismatch`` - the 4-entry
      output dims are altered as shown in the code below.

    ``mode`` and ``input_dtype`` are accepted but not read here; the output
    dtype is ``output_dtype``.

    Returns:
        Whatever ``serializer.addOutput`` returns for the new tensor.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the size arithmetic well defined for the invalid scale.
        y_num, y_den, x_num, x_den = (
            max(term, 1) for term in (y_num, y_den, x_num, x_den)
        )

    in_h, in_w = input.shape[1], input.shape[2]
    oh = ((in_h - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((in_w - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale, offset or border may have been altered for ERROR_IF tests,
        # so force the output tensor itself to stay valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_limit), min(ow, dim_limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])
        # Step by whole denominators so the non-integer error case is not
        # triggered instead of the shape mismatch.
        if change in (1, 3):
            if oh + y_den >= gtu.MAX_RESIZE_DIMENSION:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += y_den
        if change in (2, 3):
            if ow + x_den >= gtu.MAX_RESIZE_DIMENSION:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += x_den

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[0] is reused for the last entry; input.shape[3] is not
        # read on this path.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005135
5136 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type-conversion operation.

    The output has the same shape as ``val`` and the dtype ``out_dtype``.
    ``rng`` and ``error_name`` are accepted but not used.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    shape_out = val.shape
    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005139
5140 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a transpose-conv2d operation.

    ``output_shape`` is used as the output shape and ``accum_dtype`` as its
    dtype. These ERROR_IF cases adjust them:

    * ``ConvOutputShapeMismatch`` - dimension 1 and/or 2 grows by 1-3.
    * ``WrongInputType`` - the dtype falls back to ``DType.INT32``.
    * ``WrongOutputType`` - a dtype is picked from ``gtu.usableDTypes``,
      excluding the expected one (both FP16 and FP32 when ``ifm`` is FP16).

    NOTE(review): the ``ConvOutputShapeMismatch`` adjustment writes into the
    caller's ``output_shape`` list in place rather than into a copy. Confirm
    that callers rely on, or at least tolerate, this.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        picks = [1, 2, 3]
        which = rng.choice(picks)
        if which in (1, 3):
            output_shape[1] += rng.choice(picks)
        if which in (2, 3):
            output_shape[2] += rng.choice(picks)

    if error_name == ErrorIf.WrongInputType:
        # The input type is incorrect, so pick a potentially valid output.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005165
5166 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for a 2-D FFT operation.

    Both outputs use a copy of ``ifm1.shape`` and ``ifm1.dtype``. ERROR_IF
    cases selected by ``error_name``:

    * ``ErrorIf.WrongOutputType`` - the dtype is drawn from
      ``gtu.usableDTypes`` excluding ``DType.FP32``.
    * ``ErrorIf.BatchMismatch`` - dimension 0 is enlarged.
    * ``ErrorIf.FFTOutputShapeMismatch`` - dimension 1 or 2 is enlarged.

    Args:
        serializer: Serializer whose ``addOutput(shape, dtype)`` creates
            each tensor.
        rng: Random generator (``choice`` / ``integers`` are called on it).
        ifm1: First input tensor (expected 3-D unless ``ErrorIf.WrongRank``).
        ifm2: Second input tensor; must match ``ifm1``'s dtype, and its
            shape too unless ``ErrorIf.FFTInputShapeMismatch``.
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        A list with the two values returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        out_shape[bad_dim] += rng.integers(1, 10)

    # Both outputs are created with the same shape and dtype, in order.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5196
5197 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for an rfft2d operation.

    Both outputs share one shape. It is the input shape with its last
    dimension ``W`` replaced by ``W // 2 + 1``. The dtype is ``value.dtype``.
    ERROR_IF cases selected by ``error_name``:

    * ``ErrorIf.WrongOutputType`` - the dtype is drawn from
      ``gtu.usableDTypes`` excluding ``DType.FP32``.
    * ``ErrorIf.BatchMismatch`` - dimension 0 is enlarged.
    * ``ErrorIf.FFTOutputShapeMismatch`` - dimension 1 or 2 is enlarged.

    Args:
        serializer: Serializer whose ``addOutput(shape, dtype)`` creates
            each tensor.
        rng: Random generator (``choice`` / ``integers`` are called on it).
        value: Input tensor (expected 3-D unless ``ErrorIf.WrongRank``).
        error_name: Optional ``ErrorIf`` case, or ``None`` for a valid test.

    Returns:
        A list with the two values returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        out_shape[bad_dim] += rng.integers(1, 10)

    # Both outputs are created with the same shape and dtype, in order.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]