blob: bc931dcc8724a259317b5c4161c8917fb3558ef3 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of the generated C++ meta data files; the
# copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Set up the test generator from the parsed command line arguments.

    Seeds the random generator, builds the op lists, creates the desc.json
    schema validator and the data generator library, then works out the
    random value range to use for each floating point data type.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # JSON schema validation of desc.json
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clampToFPRange(userRange, fpMax):
        # "max"/"-max" become the data type limits; any other non-zero
        # bound is clipped to lie within [-fpMax, fpMax]
        bounds = []
        for bound in userRange:
            if bound == "max":
                bounds.append(fpMax)
            elif bound == "-max":
                bounds.append(-fpMax)
            elif bound < 0:
                bounds.append(max(bound, -fpMax))
            elif bound > 0:
                bounds.append(min(bound, fpMax))
            else:
                bounds.append(bound)
        return tuple(sorted(bounds))

    fpTypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    self.random_float_range = {
        fpType: clampToFPRange(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fpType],
        )
        for fpType in fpTypes
    }
84
def createSerializer(self, opName, testPath):
    """Create a new serializer that writes into <basePath>/<opName>/<testPath>.

    The test directory is created if it does not exist. The const mode is
    INPUTS for lazy data generation (constants made into files), EMBED_DUMP
    when dumping constants, and EMBED (data in the flatbuffer) otherwise.
    """
    self.testPath = os.path.join(opName, testPath)

    testDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(testDir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(testDir, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test out to its test directory.

    Always writes <testName>.tosa (the flatbuffer) and desc.json. If metaData
    is given it is stored in desc.json under "meta" before schema
    validation. Its "data_gen" section (only when lazy data generation is
    on) and its "compliance" section are also written out as C++ string
    constants in <testName>_meta_data_gen.cpp / <testName>_meta_compliance.cpp.
    """
    testDir = Path(self.basePath) / self.testPath

    # TOSA flatbuffer binary
    tosaFile = testDir / f"{testName}.tosa"
    with tosaFile.open("wb") as fd:
        fd.write(self.ser.serialize())

    # JSON descriptor from the serializer, with any extra meta data attached
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData and "data_gen" in metaData and self.args.lazy_data_gen:
        # Data generation meta data as CPP data
        with (testDir / f"{testName}_meta_data_gen.cpp").open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write("// Test meta data for data generation setup\n\n")
            fd.write(f'const char* json_tdg_config_{testDir.stem} = R"(')
            json.dump(metaData["data_gen"], fd)
            fd.write(')";\n\n')

    if metaData and "compliance" in metaData:
        # Compliance validation meta data as CPP data
        with (testDir / f"{testName}_meta_compliance.cpp").open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write("// Test meta data for compliance validation\n\n")
            fd.write(f'const char* json_tvf_config_{testDir.stem} = R"(')
            json.dump(metaData["compliance"], fd)
            fd.write(')";\n\n')

    # Write desc.json
    with (testDir / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace self.rng with a freshly seeded generator.

    When no seed is given, random_seed + 1 is used.
    """
    self.rng = np.random.default_rng(self.random_seed + 1 if seed is None else seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value range boundaries for dtype.

    Floating point types return the range worked out at start up, and
    high_inclusive does not apply to them. For every other type high is
    excluded from the range, unless high_inclusive is True, in which case
    high is reduced by one. Raises Exception for an unknown dtype.
    """
    floatTypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    if dtype in floatTypes:
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        knownRanges = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )
        for knownType, knownBounds in knownRanges:
            if dtype == knownType:
                bounds = knownBounds
                break
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

    if high_inclusive:
        # Inclusive range: low <= value <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= value < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a numpy array of random values with the given shape for dtype.

    Values are drawn from data_range (low, high) when given, otherwise from
    getDTypeRange(dtype). BOOL picks False/True. Floating point values are
    drawn uniformly: FP16 is returned as float16, FP32 as float32, and
    BF16/FP8 are converted with the gtu helpers and returned as float32.
    INT8/UINT8 use their numpy types, INT48/SHAPE use int64, and every
    other integer type uses int32.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2):
        values = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(values)
        f32Values = np.float32(values)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(f32Values))
        if dtype == DType.FP8E4M3:
            return np.float32(gtu.vect_f32_to_fp8e4m3(f32Values))
        if dtype == DType.FP8E5M2:
            return np.float32(gtu.vect_f32_to_fp8e5m2(f32Values))
        return f32Values

    if dtype == DType.INT8:
        intType = np.int8
    elif dtype == DType.UINT8:
        intType = np.uint8
    elif dtype in (DType.INT48, DType.SHAPE):
        intType = np.int64
    else:
        # All other integer types
        intType = np.int32
    return intType(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700225
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add a placeholder tensor to the serializer for each (shape, dtype) pair.

    Random data is created for each tensor unless lazy data generation is
    on, in which case None is passed as the data. Returns the results of
    addPlaceholder in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
238
def buildConstTensors(self, shape_list, dtype_list):
    """Add a const tensor to the serializer for each (shape, dtype) pair.

    Random data is created for each tensor unless lazy data generation is
    on, in which case None is passed as the data. Returns the results of
    addConst in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
251
def makeShape(self, rank):
    """Return an int32 shape array of the given rank.

    If a target shape has been set (via setTargetShape), that shape is
    returned as is, whatever its rank. Otherwise each of the rank dimensions
    is drawn at random from [tensor_shape_range[0], tensor_shape_range[1]).
    """
    # Test against None rather than truthiness: an empty (rank 0) target is
    # still a target, and a numpy array target has no single truth value.
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    low = self.args.tensor_shape_range[0]
    high = self.args.tensor_shape_range[1]
    return np.int32(self.rng.integers(low=low, high=high, size=rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700262
def setTargetShape(self, shape):
    """Make makeShape return shape; pass None to go back to random shapes."""
    self.targetted_shape = shape
265
def randInt(self, low=0, high=256):
    """Return one random np.int32 value in [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
268
def getRandNumberDType(self, dtype):
    """Return a single random value of dtype within getDTypeRange(dtype).

    FP32/FP16 give numpy float scalars. BF16 and FP8 values are drawn as
    float32 and passed through the matching gtu conversion helper. BOOL
    picks False/True, INT48/SHAPE give np.int64, and all other types give
    np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.BOOL:
        return self.rng.choice([False, True])
    if dtype in (DType.INT48, DType.SHAPE):
        # Special size
        return np.int64(self.rng.integers(low, high, size=1))[0]
    if dtype == DType.FP32:
        return np.float32(self.rng.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        asF32 = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(asF32)
        if dtype == DType.FP8E4M3:
            return gtu.vect_f32_to_fp8e4m3(asF32)
        return gtu.vect_f32_to_fp8e5m2(asF32)

    return np.int32(self.rng.integers(low, high, size=1))[0]
292
def shapeStr(self, shape):
    """Return shape as its dimensions joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700301
def typeStr(self, dtype):
    """Return the short string name of dtype.

    A list/tuple of types (at least two) gives the names of the first two
    joined with "x". Every entry is still converted, so an unknown later
    entry still raises. Raises Exception for a dtype that has no entry in
    gtu.DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(t) for t in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700315
def typeWidth(self, dtype):
    """Get the datatype width for data types.

    The width is looked up in gtu.DTYPE_ATTRIBUTES; an Exception is raised
    for a dtype with no entry there.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700322
def constrictBatchSize(self, shape):
    """Clamp shape[0] (the batch size) to args.max_batch_size, in place.

    Nothing changes when no maximum batch size is set or when explicit
    target shapes were given. Returns the (possibly modified) shape.
    """
    if not self.args.max_batch_size or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], self.args.max_batch_size)
    return shape
328
def makeDimension(self):
    """Return one random dimension size in [tensor_shape_range[0], tensor_shape_range[1])."""
    low = self.args.tensor_shape_range[0]
    high = self.args.tensor_shape_range[1]
    return self.randInt(low=low, high=high)
333
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Return the compliance meta data dict for an op's output tensor, or None.

    None is returned for ERROR_IF tests, for output types not supported by
    compliance, and for the dot product ops listed below when their input
    type is not supported. Otherwise the dict holds the compliance "mode"
    name, the output "data_type" and any mode specific information.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Compliance meta data for the expected output tensor; "mode" is filled
    # in last so that it stays the first key of the dict
    metaData = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    opCompliance = op["compliance"] if "compliance" in op else {}
    dataGenType = argsDict["dg_type"]

    if dataGenType == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        metaData["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif dataGenType == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in opCompliance:
        mode = gtu.ComplianceMode.ULP
        metaData["ulp_info"] = {"ulp": opCompliance["ulp"]}
    elif "relative" in opCompliance:
        mode = gtu.ComplianceMode.RELATIVE
        metaData["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": opCompliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        metaData["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in opCompliance:
            metaData["abs_error_info"] = {
                "lower_bound": opCompliance["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    metaData["mode"] = gtu.ComplianceMode(mode).name
    return metaData
396
397 # Build Op functions
398 # Create the output tensor (calling OutputShaper as needed)
399 # Do final tweaks to attributes (if necessary for errorIf)
400 # Add Op into graph
401 # Return resulting tensor information or BuildInfo
402
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        # A single tensor (and its optional compliance dict) is stored as a
        # one element list so both forms are handled the same way later
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return {"version": "0.1", "tensors": {name: compliance}} or None.

        Only tensors that have a compliance dict are included. None is
        returned when there is no compliance data for any tensor.
        """
        if self.complianceDictList is None:
            return None
        tensors = {
            tensor.name: info
            for tensor, info in zip(self.resultTensorList, self.complianceDictList)
            if info is not None
        }
        if not tensors:
            return None
        return {"version": "0.1", "tensors": tensors}
Eric Kunzee5e26762020-10-13 16:11:07 -0700436
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a unary op; NEGATE takes its zero points from qinfo.

    Returns a TosaTestGen.BuildInfo, or None when the ERROR_IF validators
    reject the test.
    """
    assert len(inputs) == 1
    inputTensor = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, inputTensor, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, inputTensor.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [inputTensor.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=inputTensor.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not validated:
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputTensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700489
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary op whose two inputs are broadcast to the result shape.

    validator_fcns defaults to None, matching the other build_* functions,
    so callers may omit it. qinfo is accepted but not used here.

    Returns a TosaTestGen.BuildInfo, or None when the ERROR_IF validators
    reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700531
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a binary op whose inputs are not broadcast and return its result tensor.

    validator_fcns and error_name are accepted but not used here.
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operandNames = [a.name, b.name]
    self.ser.addOperator(op["op"], operandNames, [result_tens.name])
    return result_tens
536
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARITHMETIC_RIGHT_SHIFT op using the args_dict["round"] attribute.

    Returns a TosaTestGen.BuildInfo, or None when the ERROR_IF validators
    reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Named do_round so the local does not shadow the builtin round()
    do_round = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(do_round)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800582
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MUL op with the args_dict["shift"] attribute.

    For inputs other than FP16/BF16/FP32 the result type is forced to INT32.
    For the WrongOutputType ERROR_IF test it is then replaced by a type
    picked at random from INT8, INT16 and INT48. Returns a
    TosaTestGen.BuildInfo, or None when the ERROR_IF validators reject
    the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    shift = args_dict["shift"]

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    isFloatInput = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not isFloatInput:
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrongTypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(self.rng.choice(wrongTypes))

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not validated:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700638
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TABLE op using args_dict["table"] as its table attribute.

    Returns a TosaTestGen.BuildInfo, or None when the ERROR_IF validators
    reject the test.
    """
    assert len(inputs) == 1
    inputTensor = inputs[0]
    table = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, self.rng, inputTensor, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [inputTensor.name], [result_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=inputTensor.shape,
        input_dtype=inputTensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputTensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700681
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SELECT operator from the three tensors (cond, a, b).

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # ERROR_IF tests may invalidate the operand name lists before validation.
    n_placeholder, n_const = op["operands"]
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out_tensor.name]
    )

    validation_args = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=ins,
        output_list=outs,
        num_operands=n_placeholder + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    # SELECT carries no attribute.
    self.ser.addOperator(op["op"], ins, outs)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700729
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a binary comparison operator for the two tensors in ``inputs``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # ERROR_IF tests may invalidate the operand name lists before validation.
    placeholder_cnt, const_cnt = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # Comparison operators are serialized without an attribute.
    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700777
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize an ARGMAX operator along ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` now default to None, matching the
    other build_* methods. Callers that pass them explicitly are unaffected.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: Single-element list holding the input tensor.
        args_dict: Must provide "axis".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700821
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator for the single tensor in ``inputs``.

    ``args_dict`` supplies "stride", "pad", "kernel" and, when present,
    "acc_type" (max_pool has no accumulator type, so DType.UNKNOWN is used).
    When ``qinfo`` is None the pool attribute gets zero points of 0.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # With a wrong (non 8-bit) input type, regenerate the zero points so they
    # match the input and output types actually in use.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # ERROR_IF tests may invalidate the operand name lists before validation.
    n_placeholder, n_const = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholder + n_const,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100896
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV2D operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [ifm, weight, bias] tensors.
        args_dict: Must provide "acc_type", "stride", "pad" (4 values) and
            "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to ConvAttribute. When None,
            [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700978
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [ifm, weight, bias] tensors.
        args_dict: Must provide "acc_type", "stride", "pad" (6 values) and
            "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to ConvAttribute. When None,
            [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001060
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [ifm, weight, bias] tensors.
        args_dict: Must provide "acc_type", "stride", "pad" (4 values) and
            "out_shape".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to TransposeConvAttribute. When
            None, [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [ifm, weight, bias] tensors.
        args_dict: Must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to ConvAttribute. When None,
            [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001216
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [ifm, weight, bias] tensors.
        args_dict: Must provide "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to FullyConnectedAttribute. When
            None, [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    # "weight" rather than "filter" to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001272
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL operator and return its BuildInfo.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: [a, b] tensors.
        args_dict: Must provide "acc_type".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Pair of zero points forwarded to MatMulAttribute. When None,
            [0, 0] is used instead of failing on ``qinfo[0]``.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None, so fall back to zero points of 0
    # rather than raising TypeError when indexing it below. This is done
    # after validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001322
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REDUCE_* operator along ``args_dict["axis"]``.

    ``validator_fcns`` now defaults to None, matching the other build_*
    methods. Callers that pass it explicitly are unaffected.

    For a valid (error_name is None) REDUCE_PRODUCT test this writes
    ``args_dict["n"] = a.shape[axis]`` into the caller's dict. The number of
    products is needed by the compliance metadata.

    Args:
        op: Operator description; op["op"] is the serialized Op and
            op["operands"] holds the operand counts that are summed into
            num_operands.
        inputs: Single-element list holding the input tensor.
        args_dict: Must provide "axis"; may be updated with "n" (see above).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        metadata, or None when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001371
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CLAMP operator with randomly chosen min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger ``max_val``. For the MaxSmallerMin
    ERROR_IF case the values are re-drawn until they differ and then
    swapped, so that max < min.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)
    in_dtype = tensor_in.dtype

    def draw_pair():
        # First element is drawn before the second, as in a list literal
        first = self.getRandNumberDType(in_dtype)
        second = self.getRandNumberDType(in_dtype)
        return [first, second]

    bounds = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values can never produce max < min, so draw again
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if in_dtype in (DType.BF16, DType.FP16, DType.FP32):
        if in_dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float bounds go in the last two slots; integer slots are zeroed
        clamp_args = (0, 0, min_val, max_val)
    elif in_dtype in (DType.INT8, DType.INT16):
        # Integer bounds go in the first two slots; float slots are zeroed
        clamp_args = (min_val, max_val, 0, 0)
    else:
        # to avoid internal error for incorrect input types
        clamp_args = (0, 0, 0, 0)
    clamp_attr.ClampAttribute(self.ser.builder, *clamp_args)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001440
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator whose attribute is a random FP32 value.

    Unlike the list-based builders, this takes a single tensor ``a`` and
    returns the result tensor directly. ``validator_fcns`` is not used.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [output.name], relu_attr)
    return output
1449
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PRELU operator over the single tensor ``a``.

    No attribute is attached and no ERROR_IF validation is performed
    (``validator_fcns`` is not used). Returns the result tensor directly.
    """
    output = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [output.name])
    return output
1456
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an attribute-less elementwise activation operator.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001497
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT or CONCAT_SHAPE operator joining all ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and is serialized with
    an AxisAttribute. CONCAT_SHAPE always uses axis 0 and has no attribute.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data based on the first input's dtype), or None when ERROR_IF validation
    rejects the test.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance rather than an exact type() comparison
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001556
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator taking the data tensor and a padding tensor.

    ``args_dict`` supplies ``pad`` (used for output shape computation and
    validation), ``pad_const_int`` and ``pad_const_fp``. The serialized
    PadAttribute gets an empty padding list so that ``inputs[1]`` is the
    padding source.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    tensor_in, padding_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, self.rng, tensor_in, padding, error_name)

    # write empty padding into PadAttribute to ensure inputs[1] is used
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor_in.name, padding_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
        input1=tensor_in,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001614
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator querying ``args_dict["axis"]`` of one input tensor.

    Returns a ``TosaTestGen.BuildInfo`` with no compliance meta data, or
    None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    dim_axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, self.rng, tensor_in, dim_axis, error_name)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=tensor_in.shape,
        input_dtype=tensor_in.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    no_compliance = None
    return TosaTestGen.BuildInfo(result_tensor, no_compliance)
Won Jeona21b2e82023-08-10 10:33:01 +00001660
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a RESHAPE operator taking the data tensor and a shape tensor.

    The output shape is computed from ``args_dict["new_shape"]``. The new
    shape itself is passed to the operator through ``inputs[1]``, not an
    attribute.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    tensor_in = inputs[0]
    shape_tensor = inputs[1]
    target_shape = args_dict["new_shape"]
    result_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, tensor_in, target_shape, error_name
    )

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor_in.name, shape_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001704
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator along ``args_dict["axis"]`` of one input.

    Returns a ``TosaTestGen.BuildInfo`` with no compliance meta data, or
    None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    rev_axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, tensor_in, error_name)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    no_compliance = None
    return TosaTestGen.BuildInfo(result_tensor, no_compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001744
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TRANSPOSE operator permuting one input by ``args_dict["perms"]``.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    tensor_in = inputs[0]
    permutation = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, tensor_in, permutation, error_name
    )

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(permutation)

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor_in.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        perms=permutation,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
        input1=tensor_in,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001793
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SLICE operator taking data, start and size tensors.

    The constant ``args_dict["start"]`` / ``args_dict["size"]`` values are
    used for the output shape and validation. They are also still
    serialized as a SliceAttribute.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    tensor_in, start_tensor, size_tensor = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, tensor_in, start_vals, size_vals, error_name
    )

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor_in.name, start_tensor.name, size_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
        input1=tensor_in,
    )
    if not checks_passed:
        return None

    # TODO remove the slice attribute once shape dynamism support is mature.
    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start_vals, size_vals)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001845
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TILE operator taking the data tensor and a multiples tensor.

    The output shape is computed from ``args_dict["multiples"]``. The
    operator gets the multiples through ``inputs[1]`` and has no attribute.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    tensor_in = inputs[0]
    multiples_tensor = inputs[1]
    multiples_vals = args_dict["multiples"]
    result_tensor = OutputShaper.tileOp(
        self.ser, self.rng, tensor_in, multiples_vals, error_name
    )

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tensor_in.name, multiples_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=tensor_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=tensor_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
        input1=tensor_in,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, tensor_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001890
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a GATHER operator from a values tensor and an indices tensor.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data based on the values dtype), or None when ERROR_IF validation
    rejects the test.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    result_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, indices_tensor, error_name
    )

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_tensor.name, indices_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_tensor.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001933
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SCATTER operator from values_in, indices and input tensors.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance meta
    data based on the values_in dtype), or None when ERROR_IF validation
    rejects the test.
    """
    assert len(inputs) == 3
    values_in, indices_tensor, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices_tensor, input_tensor, error_name
    )

    placeholder_cnt, const_cnt = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately invalidated
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices_tensor.name, input_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_cnt + const_cnt,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001975
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    The resize configuration ("mode", "scale", "offset", "border",
    "output_dtype") is read from ``args_dict``. The output tensor comes from
    OutputShaper.resizeOp, and the operator is serialized with a
    ResizeAttribute built from the same configuration.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance metadata,
        or None when the ERROR_IF validators reject this configuration.
    """
    assert len(inputs) == 1
    in_tens = inputs[0]
    mode, scale, offset, border, output_dtype = (
        args_dict[key]
        for key in ("mode", "scale", "offset", "border", "output_dtype")
    )

    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        in_tens,
        mode,
        scale,
        offset,
        border,
        in_tens.dtype,
        output_dtype,
        error_name,
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF tests may deliberately corrupt the operator's name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tens.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=in_tens.dtype,
        output_dtype=output_dtype,
        input_shape=in_tens.shape,
        output_shape=out_tens.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tens],
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)

    compliance_info = self.tensorComplianceMetaData(
        op, in_tens.dtype, args_dict, out_tens, error_name
    )
    return TosaTestGen.BuildInfo(out_tens, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002044
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output pass-through operator.

    Each output tensor is shaped by OutputShaper.unaryOp from its matching
    input, and a single operator mapping (val, val2) -> (out, out2) is
    serialized.

    Args:
        op: Operator description dict (its "op" entry is serialized), or a raw
            operator value, which is passed through unchanged.
        val, val2: Input tensors.
        validator_fcns: Unused by this builder.
        error_name: ERROR_IF being tested, forwarded to OutputShaper.unaryOp.

    Returns:
        The first result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # Every other builder receives the op description dict and serializes
    # op["op"]. Accept that form while still supporting a raw operator value.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
2052
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test from the single supplied tensor.

    No operator is added here. The input tensor itself is registered as a
    graph output and used as the result.

    Returns:
        TosaTestGen.BuildInfo pairing the tensor with its compliance metadata.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002065
2066 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator test converting the input to args_dict["out_type"].

    Returns:
        TosaTestGen.BuildInfo with the converted result tensor and compliance
        metadata, or None when the ERROR_IF validators reject this case.
    """
    assert len(inputs) == 1
    src_tens = inputs[0]
    target_dtype = args_dict["out_type"]

    dst_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, src_tens, target_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the name lists may be deliberately corrupted
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src_tens.name], [dst_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src_tens.shape,
        output_shape=dst_tens.shape,
        input_dtype=src_tens.dtype,
        output_dtype=dst_tens.dtype,
        result_tensors=[dst_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance_info = self.tensorComplianceMetaData(
        op, src_tens.dtype, args_dict, dst_tens, error_name
    )
    return TosaTestGen.BuildInfo(dst_tens, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002110
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Picks random input/output zero points and a random scale (one per channel
    of the last dimension, or a single one). Each scale is converted into a
    (multiplier, shift) pair. For valid scale32 tests, the saved input data
    file is rewritten so every value satisfies the apply_scale_32 range
    requirement.

    Args:
        op: Operator description dict (uses "op" and "operands").
        inputs: Single-element list holding the input tensor.
        args_dict: Provides "output_dtype", "scale" (scale32 flag),
            "double_round" and "per_channel".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF being tested, or None for a valid test.
        qinfo: Ignored; the zero points are generated inside this method.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance metadata,
        or None when the ERROR_IF validators reject this configuration.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or one scale for the whole tensor
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Zero point selection. The type width is incremented for every case that
    # may carry a non-zero zero point.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype. The comparison
        # must be element-wise: ndarray.all() collapses the array to a single
        # bool, which would make this range check pass unconditionally.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002293
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    The tensor is normally a rank-0 BOOL holding ``cond``. For the
    CondIfCondNotMatchingBool ERROR_IF the type is replaced by a wrong output
    type. For CondIfCondShapeNotSizeOne the shape is randomly [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2310
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output a constant.

    The supplied then/else tensors only provide the output shape; their
    contents are ignored. Each block gets one INT32 constant node with random
    contents. For the output-list-mismatch ERROR_IF tests, the affected
    block's constant is given a different shape.

    Returns:
        TosaTestGen.BuildInfo for the COND_IF result, or None when the ERROR_IF
        validators reject the generated graph.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs
    cond = args_dict["condition"]

    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        # Shift every dimension; dimensions <= 3 only grow so they stay positive
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator result shares the expected block output shape/type
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, block_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    compliance_info = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002395
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks apply a binary op to (a, b).

    For FP32/BF16/FP16/INT32 the then block adds and the else block subtracts.
    For INT8/INT16 the blocks use logical right/left shift instead. For the
    input/output list mismatch ERROR_IF tests, the affected block is given a
    tensor with a perturbed shape.

    Returns:
        TosaTestGen.BuildInfo for the COND_IF result, with compliance metadata
        for the block op selected by the condition, or None when the ERROR_IF
        validators reject the generated graph.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension so the block tensors no longer match. Small
        # dimensions only grow, so the shape never gets zero/negative extents
        # (same scheme as build_cond_if_const).
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2499
def build_while_loop(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a WHILE_LOOP test that repeatedly adds the input to an accumulator.

    The loop carries (iteration counter, a, accumulator).
    - COND_BLOCK outputs GREATER(counter, 0).
    - BODY_BLOCK computes acc + a and counter - 1.
    For ERROR_IF tests, the affected block inputs/outputs or the condition
    output are built with mismatching shapes or types.

    Returns:
        TosaTestGen.BuildInfo for the accumulator output, or None when the
        ERROR_IF validators reject the generated graph.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    def _perturb_dims(tens):
        # Offset every dimension of ``tens`` in place by a random non-zero amount
        for dim_idx in range(len(tens.shape)):
            tens.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        return tens

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape_src = acc
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape_src = _perturb_dims(deepcopy(acc))
    acc_out = self.ser.addIntermediate(acc_out_shape_src.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = _perturb_dims(deepcopy(iter_tens))
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give it a random extra dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = _perturb_dims(deepcopy(acc))

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for block_in in cond_inputs:
        self.ser.addInputTensor(block_in)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name != ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [3]
    else:
        cond_shape = [1, 2]
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Local intermediate tensors are declared here for the block outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for block_in in body_inputs:
        self.ser.addInputTensor(block_in)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_ref, acc_ref = incorrect_iter, incorrect_acc
    else:
        iter_ref, acc_ref = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_ref.shape, iter_ref.dtype)
    acc_body_out = self.ser.addIntermediate(acc_ref.shape, acc_ref.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for block_out in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not is_valid:
        return None

    compliance_info = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002630
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from two input tensors.

    Args:
        op: operator entry from ``TOSA_OP_LIST``.
        inputs: sequence holding exactly two input tensors.
        args_dict: test arguments; ``args_dict["inverse"]`` is passed to the
            FFT attribute.
        validator_fcns: ERROR_IF validators to check (negative tests).
        error_name: ERROR_IF condition being tested, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    first_input, second_input = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(
        self.ser, self.rng, first_input, second_input, error_name
    )
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    placeholders, consts = op["operands"]

    # May corrupt the name lists on purpose for ERROR_IF input/output list tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [first_input.name, second_input.name],
        [tens.name for tens in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=first_input,
        input2=second_input,
        input_shape=first_input.shape,
        input_dtype=first_input.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(
            op, first_input.dtype, args_dict, tens, error_name
        )
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002694
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: operator entry from ``TOSA_OP_LIST``.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check (negative tests).
        error_name: ERROR_IF condition being tested, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    results = OutputShaper.rfft2dOp(self.ser, self.rng, in_tensor, error_name)
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    placeholders, consts = op["operands"]

    # May corrupt the name lists on purpose for ERROR_IF input/output list tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [tens.name for tens in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(
            op, in_tensor.dtype, args_dict, tens, error_name
        )
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002751
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test with a single result tensor.

    The result tensor is created by ``OutputShaper.addShapeOp``.

    Args:
        op: operator entry from ``TOSA_OP_LIST``.
        inputs: sequence holding exactly two input tensors.
        args_dict: test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check (negative tests).
        error_name: ERROR_IF condition being tested, or None.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholders, consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2797
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Combines the filters requested by the caller with what the operator
    allows (``op["rank"]`` and ``op["types"]``). For negative tests the
    ERROR_IF ``validator`` may override any filter through its
    ``param_reqs``.

    Args:
        op: operator entry from ``TOSA_OP_LIST``.
        shapeFilter: requested shapes; a falsy value means ``[None]``
            (no shape restriction).
        rankFilter: requested ranks, or None for the default rank range.
        dtypeFilter: requested dtypes, or None for all operator dtypes.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator function (negative tests only).

    Returns:
        dict with ``"shapeFilter"``, ``"rankFilter"`` and ``"dtypeFilter"``
        keys, or None when no tests should be generated (a negative test
        without a validator, or an unknown ``testType``).
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # The default range must never introduce ranks the operator does
            # not support (e.g. rank 1 for an op limited to ranks 2-6), so
            # clamp it to the operator's range...
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # ...and if the two ranges do not overlap at all, use the
                # operator's lowest ranks instead.
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep the operator dtypes that were requested; dtype lists (such as
        # [Input Type 1, Input Type 2, Accumulator Type]) match on their
        # first entry.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2878
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for operator ``opName``.

    Args:
        opName: key into ``TOSA_OP_LIST``.
        shapeFilter: shapes to restrict the tests to; None (the default) is
            treated as ``[None]``, i.e. no shape restriction.
        rankFilter: ranks to restrict the tests to, or None.
        dtypeFilter: dtypes to restrict the tests to, or None.
        testType: ``"positive"`` or ``"negative"`` (ERROR_IF) tests.

    Returns:
        list of tuples ``(opName, testStr, dtype, error_name, shapeList,
        args)``.

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    # Avoid a shared mutable default argument
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator (seeded from self.random_seed)
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = f"{opName}_{shapeStr}_{typeStr}"
                        else:
                            testStr = (
                                f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                            )
                        if argStr:
                            testStr = f"{testStr}_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2985
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Create, build and serialize a single test.

    Creates a serializer for the test, generates the input tensors with the
    operator's tensor-values generator, and calls the operator's build
    function. If the build succeeds, the test is serialized together with
    any data generation and compliance meta data. Otherwise, an "Invalid
    ERROR_IF test" message is printed.

    Args:
        opName: key into ``TOSA_OP_LIST``.
        testStr: name of the test.
        dtype_or_dtypeList: one dtype for all operands, or a list of dtypes.
        error_name: ERROR_IF condition being tested, or None.
        shapeList: list of operand shapes.
        argsDict: argument dictionary; must contain ``"dg_type"``.

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    if opName not in self.TOSA_OP_LIST:
        raise Exception("Cannot find op with name {}".format(opName))
    op = self.TOSA_OP_LIST[opName]

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT ops use one dtype per input shape instead of the fixed
        # operand count
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert len(shapeList) == num_operands, (
            f"shapeList length {len(shapeList)} "
            f"must match number of operands {num_operands}"
        )
        assert len(dtypeList) == num_operands, (
            f"dtypeList length {len(dtypeList)} "
            f"must match number of operands {num_operands}"
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: add any compliance meta data, then serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003076
def createDynamicOpLists(self):
    """Expand the templated convolution ops into one op per kernel size.

    For each kernel size, a shallow copy of the matching ``*_TEMPLATE``
    entry is added to ``TOSA_OP_LIST`` as ``<family>_<k0>x<k1>[x<k2>]``.
    The copy has ``"filter"`` set to the kernel and ``"template"`` set to
    False. Every entry still flagged as a template is then removed.

    Returns early if ``"conv2d_TEMPLATE"`` is no longer present, which
    happens when the class is initialized more than once.
    """
    ops = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in ops:
        # Already created these lists
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_kernel_variant(family, kernel):
        # e.g. family "conv2d" with kernel [3, 1] -> "conv2d_3x1"
        variant = ops[family + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        ops[family + "_" + "x".join(str(dim) for dim in kernel)] = variant

    for kernel in kernels_2d:
        for family in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(family, kernel)
    for kernel in kernels_3d:
        add_kernel_variant("conv3d", kernel)

    # Collect the templates first, then delete them: keys must not be
    # removed from a dictionary while iterating over it.
    template_names = [
        name for name, entry in ops.items() if entry.get("template", False)
    ]
    for name in template_names:
        del ops[name]
3132
def initOpListDefaults(self):
    """Lint ``TOSA_OP_LIST`` and fill in optional fields.

    Each op entry must have a 2-element ``"operands"`` tuple, a 4-element
    ``"build_fcn"`` tuple, and ``"types"`` and ``"op"`` fields. A missing
    ``"rank"`` field is set to ``DEFAULT_RANK_RANGE``.

    Raises:
        Exception: for the first op found with a missing or malformed
            required field.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        # Required: (build, tensor gen, tensor values gen, arg gen) functions
        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(name)
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(name)
            )

        # Optional: fall back to the default rank range
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003174
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the operator's Op enum value (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#         if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#              TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator called by serializeTest
# 'invalid_test_validators': optional, removes positive tests that would
#     fail without matching an ERROR_IF statement
# 'error_if_validators': optional, ERROR_IF validators that drive the
#     negative test generation
# 'data_gen': optional, data generator types keyed by dtype group
#     (e.g. "fp") -- NOTE(review): consumer not visible in this section
# 'template': True for templated ops that createDynamicOpLists expands
#     (adding a 'filter' kernel to each copy) and then removes
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types (no INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range used by initOpListDefaults for ops without a 'rank' entry
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003234
3235 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003236 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003237 "argmax": {
3238 "op": Op.ARGMAX,
3239 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003240 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003241 "build_fcn": (
3242 build_argmax,
3243 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003244 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003245 TosaArgGen.agAxis,
3246 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003247 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003248 "error_if_validators": (
3249 TosaErrorValidator.evAxisSmallerZero,
3250 TosaErrorValidator.evAxisLargerRank,
3251 TosaErrorValidator.evArgmaxOutputRankMismatch,
3252 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3253 TosaErrorValidator.evWrongRank,
3254 TosaErrorValidator.evWrongInputType,
3255 TosaErrorValidator.evWrongOutputType,
3256 TosaErrorValidator.evWrongInputList,
3257 TosaErrorValidator.evWrongOutputList,
3258 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003259 "data_gen": {
3260 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3261 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "avg_pool2d": {
3264 "op": Op.AVG_POOL2D,
3265 "operands": (1, 0),
3266 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 "build_fcn": (
3268 build_pool2d,
3269 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003270 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 TosaArgGen.agPooling,
3272 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003274 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003275 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003276 "error_if_validators": (
3277 TosaErrorValidator.evKernelSmallerOne,
3278 TosaErrorValidator.evStrideSmallerOne,
3279 TosaErrorValidator.evPadSmallerZero,
3280 TosaErrorValidator.evWrongRank,
3281 TosaErrorValidator.evWrongInputType,
3282 TosaErrorValidator.evWrongOutputType,
3283 TosaErrorValidator.evWrongInputList,
3284 TosaErrorValidator.evWrongOutputList,
3285 TosaErrorValidator.evInputZeroPointNotZero,
3286 TosaErrorValidator.evOutputZeroPointNotZero,
3287 TosaErrorValidator.evPadLargerEqualKernel,
3288 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003289 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003290 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003291 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003292 "data_gen": {
3293 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3294 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003295 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003296 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003297 "conv2d_TEMPLATE": {
3298 "op": Op.CONV2D,
3299 "operands": (1, 2),
3300 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003301 "build_fcn": (
3302 build_conv2d,
3303 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003304 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003305 TosaArgGen.agConv,
3306 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003307 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003308 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003309 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3310 "error_if_validators": (
3311 TosaErrorValidator.evWrongInputType,
3312 TosaErrorValidator.evWrongOutputType,
3313 TosaErrorValidator.evWrongInputList,
3314 TosaErrorValidator.evWrongOutputList,
3315 TosaErrorValidator.evInputZeroPointNotZero,
3316 TosaErrorValidator.evWeightZeroPointNotZero,
3317 TosaErrorValidator.evPadSmallerZero,
3318 TosaErrorValidator.evStrideSmallerOne,
3319 TosaErrorValidator.evDilationSmallerOne,
3320 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003321 TosaErrorValidator.evConvOutputShapeMismatch,
3322 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003323 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003324 "data_gen": {
3325 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3326 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003327 "template": True,
3328 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003329 # Templated operator. Filled in by createDynamicOpLists
3330 "conv3d_TEMPLATE": {
3331 "op": Op.CONV3D,
3332 "operands": (1, 2),
3333 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003334 "build_fcn": (
3335 build_conv3d,
3336 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003337 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003338 TosaArgGen.agConv,
3339 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003340 "qgen": TosaQuantGen.qgConv,
3341 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003342 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3343 "error_if_validators": (
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongInputList,
3347 TosaErrorValidator.evWrongOutputList,
3348 TosaErrorValidator.evInputZeroPointNotZero,
3349 TosaErrorValidator.evWeightZeroPointNotZero,
3350 TosaErrorValidator.evPadSmallerZero,
3351 TosaErrorValidator.evStrideSmallerOne,
3352 TosaErrorValidator.evDilationSmallerOne,
3353 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003354 TosaErrorValidator.evConvOutputShapeMismatch,
3355 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003356 ),
evacha0147ab1762024-01-29 13:23:23 +00003357 "data_gen": {
3358 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3359 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003360 "template": True,
3361 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003362 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003363 "depthwise_conv2d_TEMPLATE": {
3364 "op": Op.DEPTHWISE_CONV2D,
3365 "operands": (1, 2),
3366 "filter": [1, 1],
3367 "rank": (4, 4),
3368 "build_fcn": (
3369 build_depthwise_conv2d,
3370 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003371 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003372 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003373 ),
3374 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003375 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003376 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3377 "error_if_validators": (
3378 TosaErrorValidator.evWrongInputType,
3379 TosaErrorValidator.evWrongOutputType,
3380 TosaErrorValidator.evWrongInputList,
3381 TosaErrorValidator.evWrongOutputList,
3382 TosaErrorValidator.evInputZeroPointNotZero,
3383 TosaErrorValidator.evWeightZeroPointNotZero,
3384 TosaErrorValidator.evPadSmallerZero,
3385 TosaErrorValidator.evStrideSmallerOne,
3386 TosaErrorValidator.evDilationSmallerOne,
3387 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003388 TosaErrorValidator.evConvOutputShapeMismatch,
3389 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003390 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003391 "data_gen": {
3392 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3393 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003394 "template": True,
3395 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 "fully_connected": {
3397 "op": Op.FULLY_CONNECTED,
3398 "operands": (1, 2),
3399 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 "build_fcn": (
3401 build_fully_connected,
3402 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003403 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003404 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003405 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003407 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003408 "error_if_validators": (
3409 TosaErrorValidator.evInputZeroPointNotZero,
3410 TosaErrorValidator.evWeightZeroPointNotZero,
3411 TosaErrorValidator.evWrongRank,
3412 TosaErrorValidator.evWrongInputType,
3413 TosaErrorValidator.evWrongOutputType,
3414 TosaErrorValidator.evWrongInputList,
3415 TosaErrorValidator.evWrongOutputList,
3416 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003417 "data_gen": {
3418 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3419 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003420 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003421 "matmul": {
3422 "op": Op.MATMUL,
3423 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003424 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003425 "build_fcn": (
3426 build_matmul,
3427 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003428 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003429 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003430 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003432 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003433 "error_if_validators": (
3434 TosaErrorValidator.evInputZeroPointNotZero,
3435 TosaErrorValidator.evWrongRank,
3436 TosaErrorValidator.evWrongInputType,
3437 TosaErrorValidator.evWrongOutputType,
3438 TosaErrorValidator.evWrongInputList,
3439 TosaErrorValidator.evWrongOutputList,
3440 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003441 "data_gen": {
3442 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "max_pool2d": {
3446 "op": Op.MAX_POOL2D,
3447 "operands": (1, 0),
3448 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003449 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003450 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003451 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003452 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003453 TosaArgGen.agPooling,
3454 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003455 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003456 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 "error_if_validators": (
3458 TosaErrorValidator.evKernelSmallerOne,
3459 TosaErrorValidator.evStrideSmallerOne,
3460 TosaErrorValidator.evPadSmallerZero,
3461 TosaErrorValidator.evWrongRank,
3462 TosaErrorValidator.evWrongInputType,
3463 TosaErrorValidator.evWrongOutputType,
3464 TosaErrorValidator.evWrongInputList,
3465 TosaErrorValidator.evWrongOutputList,
3466 TosaErrorValidator.evPadLargerEqualKernel,
3467 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003468 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003470 "data_gen": {
3471 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003474 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003475 "transpose_conv2d_TEMPLATE": {
3476 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003477 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003478 "rank": (4, 4),
3479 "build_fcn": (
3480 build_transpose_conv2d,
3481 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003482 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003483 TosaArgGen.agTransposeConv2D,
3484 ),
3485 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003486 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003487 "invalid_test_validators": (
3488 TosaInvalidValidator.ivHeightWidthInvalid,
3489 TosaInvalidValidator.ivNonPositiveOutputShape,
3490 ),
3491 "error_if_validators": (
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongInputList,
3495 TosaErrorValidator.evWrongOutputList,
3496 TosaErrorValidator.evInputZeroPointNotZero,
3497 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003498 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003499 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003500 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003501 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003502 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003503 "data_gen": {
3504 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3505 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003506 "template": True,
3507 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003508 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003509 "clamp": {
3510 "op": Op.CLAMP,
3511 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 "build_fcn": (
3513 build_clamp,
3514 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003515 TosaTensorValuesGen.tvgLazyGenDefault,
3516 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003517 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003518 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003519 "error_if_validators": (
3520 TosaErrorValidator.evMaxSmallerMin,
3521 TosaErrorValidator.evWrongInputType,
3522 TosaErrorValidator.evWrongOutputType,
3523 TosaErrorValidator.evWrongInputList,
3524 TosaErrorValidator.evWrongOutputList,
3525 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003526 "data_gen": {
3527 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3528 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003529 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003530 "sigmoid": {
3531 "op": Op.SIGMOID,
3532 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003533 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003534 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003536 TosaTensorValuesGen.tvgLazyGenDefault,
3537 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003538 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003539 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003540 "error_if_validators": (
3541 TosaErrorValidator.evWrongInputType,
3542 TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003546 "data_gen": {
3547 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3548 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003549 },
3550 "tanh": {
3551 "op": Op.TANH,
3552 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003554 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003555 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003556 TosaTensorValuesGen.tvgLazyGenDefault,
3557 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003559 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 "error_if_validators": (
3561 TosaErrorValidator.evWrongInputType,
3562 TosaErrorValidator.evWrongOutputType,
3563 TosaErrorValidator.evWrongInputList,
3564 TosaErrorValidator.evWrongOutputList,
3565 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003566 "data_gen": {
3567 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3568 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003569 "compliance": {
3570 "abs_error_lower_bound": 0.5,
3571 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003572 },
Won Jeon78155c62023-06-10 00:20:04 +00003573 "erf": {
3574 "op": Op.ERF,
3575 "operands": (1, 0),
3576 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003577 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003578 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003579 TosaTensorValuesGen.tvgLazyGenDefault,
3580 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003581 ),
3582 "types": TYPE_FP,
3583 "error_if_validators": (
3584 TosaErrorValidator.evWrongInputType,
3585 TosaErrorValidator.evWrongOutputType,
3586 TosaErrorValidator.evWrongInputList,
3587 TosaErrorValidator.evWrongOutputList,
3588 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003589 "data_gen": {
3590 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3591 },
3592 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003593 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 # Elementwise Binary Operators
3595 "add": {
3596 "op": Op.ADD,
3597 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 "build_fcn": (
3599 build_binary_broadcast,
3600 TosaTensorGen.tgBroadcastFuzz,
3601 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003602 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evRankMismatch,
3607 TosaErrorValidator.evWrongInputType,
3608 TosaErrorValidator.evWrongOutputType,
3609 TosaErrorValidator.evWrongInputList,
3610 TosaErrorValidator.evWrongOutputList,
3611 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003612 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003613 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003614 "data_gen": {
3615 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3616 },
3617 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003619 "arithmetic_right_shift": {
3620 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3621 "operands": (2, 0),
3622 "build_fcn": (
3623 build_arithmetic_right_shift,
3624 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003625 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 TosaArgGen.agArithmeticRightShift,
3627 ),
3628 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003629 "error_if_validators": (
3630 TosaErrorValidator.evRankMismatch,
3631 TosaErrorValidator.evWrongInputType,
3632 TosaErrorValidator.evWrongOutputType,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003636 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003639 "bitwise_and": {
3640 "op": Op.BITWISE_AND,
3641 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003642 "build_fcn": (
3643 build_binary_broadcast,
3644 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003645 TosaTensorValuesGen.tvgLazyGenDefault,
3646 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003647 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003649 "error_if_validators": (
3650 TosaErrorValidator.evRankMismatch,
3651 TosaErrorValidator.evWrongInputType,
3652 TosaErrorValidator.evWrongOutputType,
3653 TosaErrorValidator.evWrongInputList,
3654 TosaErrorValidator.evWrongOutputList,
3655 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003656 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003659 "bitwise_or": {
3660 "op": Op.BITWISE_OR,
3661 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003662 "build_fcn": (
3663 build_binary_broadcast,
3664 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003665 TosaTensorValuesGen.tvgLazyGenDefault,
3666 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003669 "error_if_validators": (
3670 TosaErrorValidator.evRankMismatch,
3671 TosaErrorValidator.evWrongInputType,
3672 TosaErrorValidator.evWrongOutputType,
3673 TosaErrorValidator.evWrongInputList,
3674 TosaErrorValidator.evWrongOutputList,
3675 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003676 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003679 "bitwise_xor": {
3680 "op": Op.BITWISE_XOR,
3681 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 "build_fcn": (
3683 build_binary_broadcast,
3684 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003685 TosaTensorValuesGen.tvgLazyGenDefault,
3686 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003689 "error_if_validators": (
3690 TosaErrorValidator.evRankMismatch,
3691 TosaErrorValidator.evWrongInputType,
3692 TosaErrorValidator.evWrongOutputType,
3693 TosaErrorValidator.evWrongInputList,
3694 TosaErrorValidator.evWrongOutputList,
3695 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003696 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003699 "intdiv": {
3700 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003701 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003702 "build_fcn": (
3703 build_binary_broadcast,
3704 TosaTensorGen.tgBroadcastFuzz,
3705 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003706 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003708 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 "error_if_validators": (
3710 TosaErrorValidator.evRankMismatch,
3711 TosaErrorValidator.evWrongInputType,
3712 TosaErrorValidator.evWrongOutputType,
3713 TosaErrorValidator.evWrongInputList,
3714 TosaErrorValidator.evWrongOutputList,
3715 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003716 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003718 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003719 "logical_and": {
3720 "op": Op.LOGICAL_AND,
3721 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003722 "build_fcn": (
3723 build_binary_broadcast,
3724 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003725 TosaTensorValuesGen.tvgLazyGenDefault,
3726 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 "error_if_validators": (
3730 TosaErrorValidator.evRankMismatch,
3731 TosaErrorValidator.evWrongInputType,
3732 TosaErrorValidator.evWrongOutputType,
3733 TosaErrorValidator.evWrongInputList,
3734 TosaErrorValidator.evWrongOutputList,
3735 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003736 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003739 "logical_left_shift": {
3740 "op": Op.LOGICAL_LEFT_SHIFT,
3741 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 "build_fcn": (
3743 build_binary_broadcast,
3744 TosaTensorGen.tgBroadcastFuzz,
3745 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003746 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003749 "error_if_validators": (
3750 TosaErrorValidator.evRankMismatch,
3751 TosaErrorValidator.evWrongInputType,
3752 TosaErrorValidator.evWrongOutputType,
3753 TosaErrorValidator.evWrongInputList,
3754 TosaErrorValidator.evWrongOutputList,
3755 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003756 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "logical_right_shift": {
3760 "op": Op.LOGICAL_RIGHT_SHIFT,
3761 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003762 "build_fcn": (
3763 build_binary_broadcast,
3764 TosaTensorGen.tgBroadcastFuzz,
3765 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003766 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003769 "error_if_validators": (
3770 TosaErrorValidator.evRankMismatch,
3771 TosaErrorValidator.evWrongInputType,
3772 TosaErrorValidator.evWrongOutputType,
3773 TosaErrorValidator.evWrongInputList,
3774 TosaErrorValidator.evWrongOutputList,
3775 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003776 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003779 "logical_or": {
3780 "op": Op.LOGICAL_OR,
3781 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003782 "build_fcn": (
3783 build_binary_broadcast,
3784 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003785 TosaTensorValuesGen.tvgLazyGenDefault,
3786 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003789 "error_if_validators": (
3790 TosaErrorValidator.evRankMismatch,
3791 TosaErrorValidator.evWrongInputType,
3792 TosaErrorValidator.evWrongOutputType,
3793 TosaErrorValidator.evWrongInputList,
3794 TosaErrorValidator.evWrongOutputList,
3795 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003796 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003799 "logical_xor": {
3800 "op": Op.LOGICAL_XOR,
3801 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 "build_fcn": (
3803 build_binary_broadcast,
3804 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003805 TosaTensorValuesGen.tvgLazyGenDefault,
3806 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003809 "error_if_validators": (
3810 TosaErrorValidator.evRankMismatch,
3811 TosaErrorValidator.evWrongInputType,
3812 TosaErrorValidator.evWrongOutputType,
3813 TosaErrorValidator.evWrongInputList,
3814 TosaErrorValidator.evWrongOutputList,
3815 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003816 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "maximum": {
3820 "op": Op.MAXIMUM,
3821 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 "build_fcn": (
3823 build_binary_broadcast,
3824 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003825 TosaTensorValuesGen.tvgLazyGenDefault,
3826 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003829 "error_if_validators": (
3830 TosaErrorValidator.evRankMismatch,
3831 TosaErrorValidator.evWrongInputType,
3832 TosaErrorValidator.evWrongOutputType,
3833 TosaErrorValidator.evWrongInputList,
3834 TosaErrorValidator.evWrongOutputList,
3835 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003836 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003837 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003838 "data_gen": {
3839 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003842 "minimum": {
3843 "op": Op.MINIMUM,
3844 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003845 "build_fcn": (
3846 build_binary_broadcast,
3847 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 TosaTensorValuesGen.tvgLazyGenDefault,
3849 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003850 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003852 "error_if_validators": (
3853 TosaErrorValidator.evRankMismatch,
3854 TosaErrorValidator.evWrongInputType,
3855 TosaErrorValidator.evWrongOutputType,
3856 TosaErrorValidator.evWrongInputList,
3857 TosaErrorValidator.evWrongOutputList,
3858 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003859 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003861 "data_gen": {
3862 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003865 "mul": {
3866 "op": Op.MUL,
3867 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003868 "build_fcn": (
3869 build_mul,
3870 TosaTensorGen.tgBroadcastFuzz,
3871 TosaTensorValuesGen.tvgMul,
3872 TosaArgGen.agMul,
3873 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003875 "error_if_validators": (
3876 TosaErrorValidator.evWrongInputType,
3877 TosaErrorValidator.evWrongOutputType,
3878 TosaErrorValidator.evWrongInputList,
3879 TosaErrorValidator.evWrongOutputList,
3880 TosaErrorValidator.evRankMismatch,
3881 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003882 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003883 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003884 "data_gen": {
3885 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3886 },
3887 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003889 "pow": {
3890 "op": Op.POW,
3891 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 "build_fcn": (
3893 build_binary_broadcast,
3894 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003895 TosaTensorValuesGen.tvgPow,
3896 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 "error_if_validators": (
3900 TosaErrorValidator.evRankMismatch,
3901 TosaErrorValidator.evWrongInputType,
3902 TosaErrorValidator.evWrongOutputType,
3903 TosaErrorValidator.evWrongInputList,
3904 TosaErrorValidator.evWrongOutputList,
3905 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003906 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003907 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003908 "data_gen": {
3909 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3910 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003912 "sub": {
3913 "op": Op.SUB,
3914 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003915 "build_fcn": (
3916 build_binary_broadcast,
3917 TosaTensorGen.tgBroadcastFuzz,
3918 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003919 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003920 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003922 "error_if_validators": (
3923 TosaErrorValidator.evRankMismatch,
3924 TosaErrorValidator.evWrongInputType,
3925 TosaErrorValidator.evWrongOutputType,
3926 TosaErrorValidator.evWrongInputList,
3927 TosaErrorValidator.evWrongOutputList,
3928 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003929 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003930 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003931 "data_gen": {
3932 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3933 },
3934 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003936 "table": {
3937 "op": Op.TABLE,
3938 # Use the automatic generation functions to create the input array
3939 # but create the table tensor in the build function, as it may be
3940 # a different type from the input
3941 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003942 "build_fcn": (
3943 build_table,
3944 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003945 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003946 TosaArgGen.agTable,
3947 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003948 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 "error_if_validators": (
3950 TosaErrorValidator.evWrongInputType,
3951 TosaErrorValidator.evWrongOutputType,
3952 TosaErrorValidator.evWrongInputList,
3953 TosaErrorValidator.evWrongOutputList,
3954 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003955 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 # Elementwise Unary operators
3957 "abs": {
3958 "op": Op.ABS,
3959 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003960 "build_fcn": (
3961 build_unary,
3962 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003963 TosaTensorValuesGen.tvgLazyGenDefault,
3964 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003965 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003967 "error_if_validators": (
3968 TosaErrorValidator.evWrongInputType,
3969 TosaErrorValidator.evWrongOutputType,
3970 TosaErrorValidator.evWrongInputList,
3971 TosaErrorValidator.evWrongOutputList,
3972 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003973 "data_gen": {
3974 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003977 "bitwise_not": {
3978 "op": Op.BITWISE_NOT,
3979 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_unary,
3982 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003983 TosaTensorValuesGen.tvgLazyGenDefault,
3984 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evWrongInputType,
3989 TosaErrorValidator.evWrongOutputType,
3990 TosaErrorValidator.evWrongInputList,
3991 TosaErrorValidator.evWrongOutputList,
3992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 "ceil": {
3995 "op": Op.CEIL,
3996 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003997 "build_fcn": (
3998 build_unary,
3999 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004000 TosaTensorValuesGen.tvgLazyGenDefault,
4001 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004002 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004004 "error_if_validators": (
4005 TosaErrorValidator.evWrongInputType,
4006 TosaErrorValidator.evWrongOutputType,
4007 TosaErrorValidator.evWrongInputList,
4008 TosaErrorValidator.evWrongOutputList,
4009 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004010 "data_gen": {
4011 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4012 },
4013 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "clz": {
4016 "op": Op.CLZ,
4017 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004018 "build_fcn": (
4019 build_unary,
4020 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004021 TosaTensorValuesGen.tvgLazyGenDefault,
4022 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004023 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004025 "error_if_validators": (
4026 TosaErrorValidator.evWrongInputType,
4027 TosaErrorValidator.evWrongOutputType,
4028 TosaErrorValidator.evWrongInputList,
4029 TosaErrorValidator.evWrongOutputList,
4030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 "exp": {
4033 "op": Op.EXP,
4034 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004035 "build_fcn": (
4036 build_unary,
4037 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004038 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004039 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004042 "error_if_validators": (
4043 TosaErrorValidator.evWrongInputType,
4044 TosaErrorValidator.evWrongOutputType,
4045 TosaErrorValidator.evWrongInputList,
4046 TosaErrorValidator.evWrongOutputList,
4047 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004048 "data_gen": {
4049 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4050 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004052 "floor": {
4053 "op": Op.FLOOR,
4054 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004055 "build_fcn": (
4056 build_unary,
4057 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004058 TosaTensorValuesGen.tvgLazyGenDefault,
4059 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004062 "error_if_validators": (
4063 TosaErrorValidator.evWrongInputType,
4064 TosaErrorValidator.evWrongOutputType,
4065 TosaErrorValidator.evWrongInputList,
4066 TosaErrorValidator.evWrongOutputList,
4067 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004068 "data_gen": {
4069 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4070 },
4071 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004073 "log": {
4074 "op": Op.LOG,
4075 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004076 "build_fcn": (
4077 build_unary,
4078 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004079 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004080 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004081 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004083 "error_if_validators": (
4084 TosaErrorValidator.evWrongInputType,
4085 TosaErrorValidator.evWrongOutputType,
4086 TosaErrorValidator.evWrongInputList,
4087 TosaErrorValidator.evWrongOutputList,
4088 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004089 "data_gen": {
4090 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4091 },
4092 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 "logical_not": {
4095 "op": Op.LOGICAL_NOT,
4096 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004097 "build_fcn": (
4098 build_unary,
4099 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004100 TosaTensorValuesGen.tvgLazyGenDefault,
4101 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004104 "error_if_validators": (
4105 TosaErrorValidator.evWrongInputType,
4106 TosaErrorValidator.evWrongOutputType,
4107 TosaErrorValidator.evWrongInputList,
4108 TosaErrorValidator.evWrongOutputList,
4109 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004111 "negate": {
4112 "op": Op.NEGATE,
4113 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004114 "build_fcn": (
4115 build_unary,
4116 TosaTensorGen.tgBasic,
4117 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004118 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004119 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 "qgen": TosaQuantGen.qgUnary,
4121 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004122 "error_if_validators": (
4123 TosaErrorValidator.evInputZeroPointNotZero,
4124 TosaErrorValidator.evOutputZeroPointNotZero,
4125 TosaErrorValidator.evWrongInputType,
4126 TosaErrorValidator.evWrongOutputType,
4127 TosaErrorValidator.evWrongInputList,
4128 TosaErrorValidator.evWrongOutputList,
4129 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004130 "data_gen": {
4131 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4132 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004133 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004134 "reciprocal": {
4135 "op": Op.RECIPROCAL,
4136 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004137 "build_fcn": (
4138 build_unary,
4139 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004140 TosaTensorValuesGen.tvgLazyGenDefault,
4141 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004144 "error_if_validators": (
4145 TosaErrorValidator.evWrongInputType,
4146 TosaErrorValidator.evWrongOutputType,
4147 TosaErrorValidator.evWrongInputList,
4148 TosaErrorValidator.evWrongOutputList,
4149 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004150 "data_gen": {
4151 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4152 },
4153 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004154 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004155 "rsqrt": {
4156 "op": Op.RSQRT,
4157 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004158 "build_fcn": (
4159 build_unary,
4160 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004161 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004162 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004165 "error_if_validators": (
4166 TosaErrorValidator.evWrongInputType,
4167 TosaErrorValidator.evWrongOutputType,
4168 TosaErrorValidator.evWrongInputList,
4169 TosaErrorValidator.evWrongOutputList,
4170 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004171 "data_gen": {
4172 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4173 },
4174 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004176 # Elementwise Ternary operators
4177 "select": {
4178 "op": Op.SELECT,
4179 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004180 "build_fcn": (
4181 build_select,
4182 TosaTensorGen.tgBroadcastFuzz,
4183 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004184 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004185 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004187 "error_if_validators": (
4188 TosaErrorValidator.evRankMismatch,
4189 TosaErrorValidator.evWrongInputType,
4190 TosaErrorValidator.evWrongOutputType,
4191 TosaErrorValidator.evWrongInputList,
4192 TosaErrorValidator.evWrongOutputList,
4193 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004194 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004195 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004196 "data_gen": {
4197 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 # Comparison operators
4201 "equal": {
4202 "op": Op.EQUAL,
4203 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004204 "build_fcn": (
4205 build_comparison,
4206 TosaTensorGen.tgBroadcastFuzz,
4207 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004208 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004209 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004211 "error_if_validators": (
4212 TosaErrorValidator.evRankMismatch,
4213 TosaErrorValidator.evWrongInputType,
4214 TosaErrorValidator.evWrongOutputType,
4215 TosaErrorValidator.evWrongInputList,
4216 TosaErrorValidator.evWrongOutputList,
4217 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004218 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004219 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004220 "data_gen": {
4221 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004224 "greater_equal": {
4225 "op": Op.GREATER_EQUAL,
4226 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004227 "build_fcn": (
4228 build_comparison,
4229 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004230 TosaTensorValuesGen.tvgLazyGenDefault,
4231 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004234 "error_if_validators": (
4235 TosaErrorValidator.evRankMismatch,
4236 TosaErrorValidator.evWrongInputType,
4237 TosaErrorValidator.evWrongOutputType,
4238 TosaErrorValidator.evWrongInputList,
4239 TosaErrorValidator.evWrongOutputList,
4240 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004241 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004242 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004243 "data_gen": {
4244 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4245 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 "greater": {
4248 "op": Op.GREATER,
4249 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004250 "build_fcn": (
4251 build_comparison,
4252 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004253 TosaTensorValuesGen.tvgLazyGenDefault,
4254 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004255 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004257 "error_if_validators": (
4258 TosaErrorValidator.evRankMismatch,
4259 TosaErrorValidator.evWrongInputType,
4260 TosaErrorValidator.evWrongOutputType,
4261 TosaErrorValidator.evWrongInputList,
4262 TosaErrorValidator.evWrongOutputList,
4263 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004264 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004265 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004266 "data_gen": {
4267 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004270 # Reduction operators
4271 "reduce_all": {
4272 "op": Op.REDUCE_ALL,
4273 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004274 "build_fcn": (
4275 build_reduce,
4276 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004277 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004278 TosaArgGen.agAxis,
4279 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004280 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004281 "error_if_validators": (
4282 TosaErrorValidator.evAxisLargerRank,
4283 TosaErrorValidator.evAxisSmallerZero,
4284 TosaErrorValidator.evShapeOfAxisNotOne,
4285 TosaErrorValidator.evWrongInputType,
4286 TosaErrorValidator.evWrongOutputType,
4287 TosaErrorValidator.evWrongRank,
4288 TosaErrorValidator.evWrongInputList,
4289 TosaErrorValidator.evWrongOutputList,
4290 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004292 "reduce_any": {
4293 "op": Op.REDUCE_ANY,
4294 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004295 "build_fcn": (
4296 build_reduce,
4297 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004298 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 TosaArgGen.agAxis,
4300 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004302 "error_if_validators": (
4303 TosaErrorValidator.evAxisLargerRank,
4304 TosaErrorValidator.evAxisSmallerZero,
4305 TosaErrorValidator.evShapeOfAxisNotOne,
4306 TosaErrorValidator.evWrongInputType,
4307 TosaErrorValidator.evWrongOutputType,
4308 TosaErrorValidator.evWrongRank,
4309 TosaErrorValidator.evWrongInputList,
4310 TosaErrorValidator.evWrongOutputList,
4311 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004312 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004313 "reduce_max": {
4314 "op": Op.REDUCE_MAX,
4315 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 "build_fcn": (
4317 build_reduce,
4318 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004319 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004320 TosaArgGen.agAxis,
4321 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004323 "error_if_validators": (
4324 TosaErrorValidator.evAxisLargerRank,
4325 TosaErrorValidator.evAxisSmallerZero,
4326 TosaErrorValidator.evShapeOfAxisNotOne,
4327 TosaErrorValidator.evWrongInputType,
4328 TosaErrorValidator.evWrongOutputType,
4329 TosaErrorValidator.evWrongRank,
4330 TosaErrorValidator.evWrongInputList,
4331 TosaErrorValidator.evWrongOutputList,
4332 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004333 "data_gen": {
4334 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4335 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004337 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004338 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004339 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004340 "build_fcn": (
4341 build_reduce,
4342 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004343 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004344 TosaArgGen.agAxis,
4345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004346 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004347 "error_if_validators": (
4348 TosaErrorValidator.evAxisLargerRank,
4349 TosaErrorValidator.evAxisSmallerZero,
4350 TosaErrorValidator.evShapeOfAxisNotOne,
4351 TosaErrorValidator.evWrongInputType,
4352 TosaErrorValidator.evWrongOutputType,
4353 TosaErrorValidator.evWrongRank,
4354 TosaErrorValidator.evWrongInputList,
4355 TosaErrorValidator.evWrongOutputList,
4356 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004357 "data_gen": {
4358 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004361 "reduce_product": {
4362 "op": Op.REDUCE_PRODUCT,
4363 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004364 "build_fcn": (
4365 build_reduce,
4366 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004367 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004368 TosaArgGen.agAxis,
4369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004371 "error_if_validators": (
4372 TosaErrorValidator.evAxisLargerRank,
4373 TosaErrorValidator.evAxisSmallerZero,
4374 TosaErrorValidator.evShapeOfAxisNotOne,
4375 TosaErrorValidator.evWrongInputType,
4376 TosaErrorValidator.evWrongOutputType,
4377 TosaErrorValidator.evWrongRank,
4378 TosaErrorValidator.evWrongInputList,
4379 TosaErrorValidator.evWrongOutputList,
4380 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004381 "data_gen": {
4382 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004384 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004385 "reduce_sum": {
4386 "op": Op.REDUCE_SUM,
4387 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004388 "build_fcn": (
4389 build_reduce,
4390 TosaTensorGen.tgBasic,
4391 TosaTensorValuesGen.tvgReduceSum,
4392 TosaArgGen.agAxis,
4393 ),
James Ward24dbc422022-10-19 12:20:31 +01004394 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004395 "error_if_validators": (
4396 TosaErrorValidator.evAxisLargerRank,
4397 TosaErrorValidator.evAxisSmallerZero,
4398 TosaErrorValidator.evShapeOfAxisNotOne,
4399 TosaErrorValidator.evWrongInputType,
4400 TosaErrorValidator.evWrongOutputType,
4401 TosaErrorValidator.evWrongRank,
4402 TosaErrorValidator.evWrongInputList,
4403 TosaErrorValidator.evWrongOutputList,
4404 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004405 "data_gen": {
4406 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4407 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004408 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004409 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004410 "concat": {
4411 "op": Op.CONCAT,
4412 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004413 "build_fcn": (
4414 build_concat,
4415 TosaTensorGen.tgConcat,
4416 TosaTensorValuesGen.tvgConcat,
4417 TosaArgGen.agAxis,
4418 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004419 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004420 "error_if_validators": (
4421 TosaErrorValidator.evAxisLargerRank,
4422 TosaErrorValidator.evAxisSmallerZero,
4423 TosaErrorValidator.evConcatInputRankMismatch,
4424 TosaErrorValidator.evConcatShapeSumMismatch,
4425 TosaErrorValidator.evConcatInputDimMismatch,
4426 TosaErrorValidator.evWrongInputType,
4427 TosaErrorValidator.evWrongOutputType,
4428 TosaErrorValidator.evWrongOutputList,
4429 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004430 "data_gen": {
4431 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4432 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004433 },
4434 "pad": {
4435 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004436 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 "build_fcn": (
4438 build_pad,
4439 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004440 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004441 TosaArgGen.agPad,
4442 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004443 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 "error_if_validators": (
4445 TosaErrorValidator.evWrongInputType,
4446 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004447 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004448 TosaErrorValidator.evWrongOutputType,
4449 TosaErrorValidator.evWrongInputList,
4450 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004451 TosaErrorValidator.evRankMismatch,
4452 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004453 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004454 "data_gen": {
4455 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4456 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004457 },
Won Jeona21b2e82023-08-10 10:33:01 +00004458 "dim": {
4459 "op": Op.DIM,
4460 "operands": (1, 0),
4461 "build_fcn": (
4462 build_dim,
4463 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004464 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004465 TosaArgGen.agAxis,
4466 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004467 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004468 "error_if_validators": (
4469 TosaErrorValidator.evAxisLargerRank,
4470 TosaErrorValidator.evAxisSmallerZero,
4471 TosaErrorValidator.evWrongInputType,
4472 TosaErrorValidator.evWrongInputList,
4473 TosaErrorValidator.evWrongOutputList,
4474 TosaErrorValidator.evWrongRank,
4475 ),
4476 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004477 "reshape": {
4478 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004479 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004480 "build_fcn": (
4481 build_reshape,
4482 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004483 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004484 TosaArgGen.agReshape,
4485 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004486 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004487 "error_if_validators": (
4488 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4489 TosaErrorValidator.evWrongInputType,
4490 TosaErrorValidator.evWrongOutputType,
4491 TosaErrorValidator.evWrongInputList,
4492 TosaErrorValidator.evWrongOutputList,
4493 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004494 "data_gen": {
4495 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4496 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004497 },
4498 "reverse": {
4499 "op": Op.REVERSE,
4500 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004501 "build_fcn": (
4502 build_reverse,
4503 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004504 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004505 TosaArgGen.agAxis,
4506 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004507 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004508 "error_if_validators": (
4509 TosaErrorValidator.evAxisSmallerZero,
4510 TosaErrorValidator.evAxisLargerRank,
4511 TosaErrorValidator.evWrongInputType,
4512 TosaErrorValidator.evWrongOutputType,
4513 TosaErrorValidator.evWrongInputList,
4514 TosaErrorValidator.evWrongOutputList,
4515 ),
evacha0198477222024-01-26 12:25:32 +00004516 "data_gen": {
4517 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4518 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004519 },
4520 "slice": {
4521 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004522 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004523 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004524 "build_fcn": (
4525 build_slice,
4526 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004527 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004528 TosaArgGen.agSlice,
4529 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004530 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004531 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004532 # TODO Turn off these error categories for now as the reference
4533             # model cannot allocate memory space for an empty tensor. We probably
4534             # can report an accurate error message at the right place during
4535             # execution.
4536 # TosaErrorValidator.evStartSmallerZero,
4537 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 TosaErrorValidator.evStartSizeOutsideBounds,
4539 TosaErrorValidator.evSizeOutputShapeMismatch,
4540 TosaErrorValidator.evInputSizeStartLengthMismatch,
4541 TosaErrorValidator.evWrongRank,
4542 TosaErrorValidator.evWrongInputType,
4543 TosaErrorValidator.evWrongOutputType,
4544 TosaErrorValidator.evWrongInputList,
4545 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004546 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004547 ),
evacha017f7d4252024-01-24 12:08:09 +00004548 "data_gen": {
4549 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4550 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004551 },
4552 "tile": {
4553 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004554 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004555 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004556 "build_fcn": (
4557 build_tile,
4558 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004559 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004560 TosaArgGen.agTile,
4561 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004562 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 "error_if_validators": (
4564 TosaErrorValidator.evWrongInputType,
4565 TosaErrorValidator.evWrongOutputType,
4566 TosaErrorValidator.evWrongInputList,
4567 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004568 TosaErrorValidator.evRankMismatch,
4569 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004570 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004571 "data_gen": {
4572 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4573 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004574 },
4575 "transpose": {
4576 "op": Op.TRANSPOSE,
4577 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004578 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004579 "build_fcn": (
4580 build_transpose,
4581 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004582 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004583 TosaArgGen.agTranspose,
4584 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004585 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004586 "error_if_validators": (
4587 TosaErrorValidator.evIndexOutsideBounds,
4588 TosaErrorValidator.evIndexUsedTwice,
4589 TosaErrorValidator.evWrongInputType,
4590 TosaErrorValidator.evWrongOutputType,
4591 TosaErrorValidator.evWrongInputList,
4592 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004593 TosaErrorValidator.evWrongRank,
4594 TosaErrorValidator.evRankMismatch,
4595 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004596 ),
evacha0198477222024-01-26 12:25:32 +00004597 "data_gen": {
4598 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4599 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004600 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004601 # Data nodes
4602 "const": {
4603 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004604 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004605 "build_fcn": (
4606 build_const,
4607 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004608 TosaTensorValuesGen.tvgLazyGenDefault,
4609 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004610 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004611 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004612 "data_gen": {
4613 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004615 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004616 "identity": {
4617 "op": Op.IDENTITY,
4618 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004619 "build_fcn": (
4620 build_unary,
4621 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004622 TosaTensorValuesGen.tvgLazyGenDefault,
4623 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004625 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004626 "data_gen": {
4627 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004629 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004630 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004631 "gather": {
4632 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004633 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004634 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004635 "build_fcn": (
4636 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004637 TosaTensorGen.tgGather,
4638 TosaTensorValuesGen.tvgGather,
4639 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004640 ),
James Ward24dbc422022-10-19 12:20:31 +01004641 "types": (
4642 DType.INT8,
4643 DType.INT16,
4644 DType.INT32,
4645 DType.FP16,
4646 DType.BF16,
4647 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004648 DType.FP8E4M3,
4649 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004650 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004651 "error_if_validators": (
4652 TosaErrorValidator.evWrongInputType,
4653 TosaErrorValidator.evWrongOutputType,
4654 TosaErrorValidator.evWrongInputList,
4655 TosaErrorValidator.evWrongOutputList,
4656 TosaErrorValidator.evWrongRank,
4657 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004658 "data_gen": {
4659 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4660 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004661 },
4662 "scatter": {
4663 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004664 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004665 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004666 "build_fcn": (
4667 build_scatter,
4668 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004669 TosaTensorValuesGen.tvgScatter,
4670 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004671 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004672 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004673 "error_if_validators": (
4674 TosaErrorValidator.evWrongInputType,
4675 TosaErrorValidator.evWrongOutputType,
4676 TosaErrorValidator.evWrongInputList,
4677 TosaErrorValidator.evWrongOutputList,
4678 TosaErrorValidator.evWrongRank,
4679 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004680 "data_gen": {
4681 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4682 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004683 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004684 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004685 "resize": {
4686 "op": Op.RESIZE,
4687 "operands": (1, 0),
4688 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004689 "build_fcn": (
4690 build_resize,
4691 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004692 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004693 TosaArgGen.agResize,
4694 ),
James Ward24dbc422022-10-19 12:20:31 +01004695 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004696 "invalid_test_validators": (
4697 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004698 ),
4699 "error_if_validators": (
4700 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004701 TosaErrorValidator.evScaleSmallerEqualZero,
4702 TosaErrorValidator.evScaleNLargerMax,
4703 TosaErrorValidator.evScaleDLargerMax,
4704 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004705 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004706 TosaErrorValidator.evBorderSmallerMin,
4707 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004708 TosaErrorValidator.evWrongInputType,
4709 TosaErrorValidator.evWrongOutputType,
4710 TosaErrorValidator.evWrongRank,
4711 TosaErrorValidator.evWrongInputList,
4712 TosaErrorValidator.evWrongOutputList,
4713 TosaErrorValidator.evBatchMismatch,
4714 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004715 TosaErrorValidator.evResizeOutputShapeMismatch,
4716 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004717 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004718 "data_gen": {
4719 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4720 },
4721 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004722 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004723 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004724 "cast": {
4725 "op": Op.CAST,
4726 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004727 "build_fcn": (
4728 build_cast,
4729 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004730 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004731 TosaArgGen.agCast,
4732 ),
James Ward8b390432022-08-12 20:48:56 +01004733 "types": (
4734 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004735 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004736 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004737 DType.INT8,
4738 DType.INT16,
4739 DType.INT32,
4740 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004741 DType.FP8E4M3,
4742 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004743 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004744 "error_if_validators": (
4745 TosaErrorValidator.evWrongInputType,
4746 TosaErrorValidator.evWrongOutputType,
4747 TosaErrorValidator.evWrongInputList,
4748 TosaErrorValidator.evWrongOutputList,
4749 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004750 "data_gen": {
4751 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4752 },
4753 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004754 },
4755 "rescale": {
4756 "op": Op.RESCALE,
4757 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004758 "build_fcn": (
4759 build_rescale,
4760 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004761 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004762 TosaArgGen.agRescale,
4763 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004764 "types": [
4765 DType.UINT8,
4766 DType.INT8,
4767 DType.INT16,
4768 DType.INT32,
4769 DType.INT48,
4770 DType.UINT16,
4771 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004772 "error_if_validators": (
4773 TosaErrorValidator.evInputZeroPointNotZero,
4774 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004775 TosaErrorValidator.evU16InputZeroPointNotValid,
4776 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004777 TosaErrorValidator.evScaleTrue,
4778 TosaErrorValidator.evScaleNotTrue,
4779 TosaErrorValidator.evWrongInputType,
4780 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004781 TosaErrorValidator.evWrongInputList,
4782 TosaErrorValidator.evWrongOutputList,
4783 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004784 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004785 # Custom
4786 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004787 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004788 # Two variants of cond_if, one that generates one of two constant tensors (no
4789 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4790 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004791 "cond_if_const": {
4792 "op": Op.COND_IF,
4793 "operands": (0, 2),
4794 "build_fcn": (
4795 build_cond_if_const,
4796 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004797 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004798 TosaArgGen.agCondIf,
4799 ),
4800 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004801 "error_if_validators": (
4802 TosaErrorValidator.evOutputListThenGraphMismatch,
4803 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004804 TosaErrorValidator.evCondIfCondNotMatchingBool,
4805 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004806 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004807 },
4808 "cond_if_binary": {
4809 "op": Op.COND_IF,
4810 "operands": (2, 0),
4811 "build_fcn": (
4812 build_cond_if_binary,
4813 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004814 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004815 TosaArgGen.agCondIf,
4816 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004817 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004818 "error_if_validators": (
4819 TosaErrorValidator.evInputListThenGraphMismatch,
4820 TosaErrorValidator.evInputListElseGraphMismatch,
4821 TosaErrorValidator.evOutputListThenGraphMismatch,
4822 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004823 TosaErrorValidator.evCondIfCondNotMatchingBool,
4824 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004825 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004826 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004827 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004828 "while_loop": {
4829 "op": Op.WHILE_LOOP,
4830 "operands": (0, 1),
4831 "build_fcn": (
4832 build_while_loop,
4833 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004834 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004835 TosaArgGen.agWhileLoop,
4836 ),
4837 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004838 "error_if_validators": (
4839 TosaErrorValidator.evInputListOutputListMismatch,
4840 TosaErrorValidator.evInputListCondGraphMismatch,
4841 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4842 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4843 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004844 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004845 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004846 },
Luke Hutton57287132023-02-06 14:54:18 +00004847 "fft2d": {
4848 "op": Op.FFT2D,
4849 "operands": (2, 0),
4850 "rank": (3, 3),
4851 "build_fcn": (
4852 build_fft2d,
4853 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004854 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004855 TosaArgGen.agFFT2d,
4856 ),
4857 "types": [DType.FP32],
4858 "error_if_validators": (
4859 TosaErrorValidator.evWrongInputType,
4860 TosaErrorValidator.evWrongOutputType,
4861 TosaErrorValidator.evWrongInputList,
4862 TosaErrorValidator.evWrongOutputList,
4863 TosaErrorValidator.evWrongRank,
4864 TosaErrorValidator.evBatchMismatch,
4865 TosaErrorValidator.evKernelNotPowerOfTwo,
4866 TosaErrorValidator.evFFTInputShapeMismatch,
4867 TosaErrorValidator.evFFTOutputShapeMismatch,
4868 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004869 "data_gen": {
4870 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4871 },
Luke Hutton57287132023-02-06 14:54:18 +00004872 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004873 "rfft2d": {
4874 "op": Op.RFFT2D,
4875 "operands": (1, 0),
4876 "rank": (3, 3),
4877 "build_fcn": (
4878 build_rfft2d,
4879 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004880 TosaTensorValuesGen.tvgLazyGenDefault,
4881 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004882 ),
4883 "types": [DType.FP32],
4884 "error_if_validators": (
4885 TosaErrorValidator.evWrongInputType,
4886 TosaErrorValidator.evWrongOutputType,
4887 TosaErrorValidator.evWrongInputList,
4888 TosaErrorValidator.evWrongOutputList,
4889 TosaErrorValidator.evWrongRank,
4890 TosaErrorValidator.evBatchMismatch,
4891 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004892 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004893 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004894 "data_gen": {
4895 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4896 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004897 },
Won Jeon74342e52024-01-09 00:34:40 +00004898 # Shape
4899 "add_shape": {
4900 "op": Op.ADD_SHAPE,
4901 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004902 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004903 "build_fcn": (
4904 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004905 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004906 TosaTensorValuesGen.tvgAddSub,
4907 TosaArgGen.agNone,
4908 ),
4909 "types": [DType.SHAPE],
4910 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4911 },
4912 "sub_shape": {
4913 "op": Op.SUB_SHAPE,
4914 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004915 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004916 "build_fcn": (
4917 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004918 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004919 TosaTensorValuesGen.tvgAddSub,
4920 TosaArgGen.agNone,
4921 ),
4922 "types": [DType.SHAPE],
4923 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4924 },
4925 "mul_shape": {
4926 "op": Op.MUL_SHAPE,
4927 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004928 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004929 "build_fcn": (
4930 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004931 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004932 TosaTensorValuesGen.tvgMul,
4933 TosaArgGen.agNone,
4934 ),
4935 "types": [DType.SHAPE],
4936 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4937 },
4938 "div_shape": {
4939 "op": Op.DIV_SHAPE,
4940 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004941 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004942 "build_fcn": (
4943 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004944 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004945 TosaTensorValuesGen.tvgIntDiv,
4946 TosaArgGen.agNone,
4947 ),
4948 "types": [DType.SHAPE],
4949 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4950 },
4951 "concat_shape": {
4952 "op": Op.CONCAT_SHAPE,
4953 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004954 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004955 "build_fcn": (
4956 build_concat,
4957 TosaTensorGen.tgConcat,
4958 TosaTensorValuesGen.tvgConcat,
4959 TosaArgGen.agNone,
4960 ),
4961 "types": [DType.SHAPE],
4962 "error_if_validators": (),
4963 },
4964 "const_shape": {
4965 "op": Op.CONST_SHAPE,
4966 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004967 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004968 "build_fcn": (
4969 build_const,
4970 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004971 TosaTensorValuesGen.tvgLazyGenDefault,
4972 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004973 ),
4974 "types": [DType.SHAPE],
4975 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004976 }
4977
Kevin Cheng550ccc52021-03-03 11:21:43 -08004978
Eric Kunzee5e26762020-10-13 16:11:07 -07004979class OutputShaper:
4980 # Methods in this class compute the expected output shape and datatype
4981 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; no instance state is initialised here."""
4984
4985 # These methods return arguments that can be used for
4986 # creating a new output tensor
4987 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise binary op with broadcasting.

    With no error case requested, every size-1 dimension of ``a`` takes the
    matching dimension of ``b``; all other dimensions are copied from ``a``.

    Error cases handled here:
      * ``ErrorIf.DimensionMismatch`` - one randomly chosen dimension is
        grown by 1.
      * ``ErrorIf.WrongOutputType`` - the dtype is drawn at random from a
        fixed candidate list with ``a.dtype`` removed.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``integers`` and ``choice`` are used).
        a: first input tensor; its ``shape`` and ``dtype`` are read.
        b: second input tensor; must have the same dtype as ``a``.
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Differing ranks are only tolerated when a rank mismatch is the test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn on every call, not only for DimensionMismatch, so rng is consumed
    # here regardless of error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005020
5021 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose inputs must match exactly.

    Both inputs must have the same rank, the same dtype and identical
    dimensions (checked with ``assert``). The output copies ``a``'s shape
    and dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        a: first input tensor; its ``shape`` and ``dtype`` are read.
        b: second input tensor.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
        out_shape.append(dim_a)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005032
5033 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an element-wise unary op.

    The output has ``a``'s shape. Its dtype is ``a.dtype`` unless
    ``ErrorIf.WrongOutputType`` is requested. In that case the dtype is drawn
    at random from a fixed candidate list with ``a.dtype`` removed.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``choice`` is used).
        a: input tensor; its ``shape`` and ``dtype`` are read.
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005051
5052 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    With no error case requested, a size-1 dimension of ``cond`` becomes the
    largest of the matching ``cond``/``a``/``b`` dimensions; other dimensions
    are copied from ``cond``. When an error case is requested, ``cond``'s
    shape is used unchanged before the perturbations below.

    Error cases handled here:
      * ``ErrorIf.DimensionMismatch`` - one random dimension is grown by 1.
      * ``ErrorIf.WrongOutputType`` - random dtype from a fixed list with
        ``a.dtype`` removed.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``integers`` and ``choice`` are used).
        cond: condition tensor; its ``shape`` is read.
        a: first value tensor; its ``shape`` and ``dtype`` are read.
        b: second value tensor; must have the same dtype as ``a``.
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Differing ranks are only tolerated when a rank mismatch is the test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if error_name is None and cond_dim == 1:
            out_shape.append(max(cond_dim, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_dim)

    # Drawn on every call, so rng is consumed here regardless of error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005085
5086 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an element-wise comparison op.

    Broadcasting: each size-1 dimension of ``a`` takes ``b``'s size at that
    index when ``b`` has that dimension. Otherwise ``a``'s size is kept.
    Unlike the other broadcast helpers, this is applied for error cases too.

    Error cases handled here:
      * ``ErrorIf.DimensionMismatch`` - one random dimension is grown by 1.
      * ``ErrorIf.WrongOutputType`` - a random non-BOOL dtype from a fixed
        list is used instead of BOOL.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``integers`` and ``choice`` are used).
        a: first input tensor; its ``shape`` and ``dtype`` are read.
        b: second input tensor; must have the same dtype as ``a``.
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    # Differing ranks are only tolerated when a rank mismatch is the test.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn on every call, so rng is consumed here regardless of error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so any pick is a wrong output type.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005119
5120 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction of ``a`` along ``axis``.

    Normally the output is ``a``'s shape with dimension ``axis`` set to 1.

    Error cases handled here:
      * ``AxisSmallerZero`` / ``AxisLargerRank`` - the shape is not modified.
      * ``ShapeOfAxisNotOne`` - dimension ``axis`` keeps its input size. If
        that size is 1, it is replaced by ``rng.integers(2, 10)``.
      * ``WrongOutputType`` - random dtype from a fixed list with
        ``a.dtype`` removed.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``integers`` and ``choice`` are used).
        a: input tensor; ``a.shape`` must support ``.copy()``.
        axis: reduction axis (may be invalid in axis error cases).
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    keep_axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keep_axis_errors:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005148
5149 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for ARGMAX of ``a`` along ``axis``.

    Normally the output is ``a``'s shape with dimension ``axis`` removed.

    Error cases handled here:
      * ``AxisSmallerZero`` / ``AxisLargerRank`` - no dimension is removed.
      * ``ArgmaxOutputRankMismatch`` - the first dimension is dropped (when a
        random coin says so and more than one dimension remains), otherwise
        a trailing 1 is appended.
      * ``ArgmaxOutputShapeMismatch`` - every dimension grows by a random
        amount in ``rng.integers(1, 10)``.
      * ``WrongOutputType`` - random dtype from a fixed list other than INT32.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``integers`` and ``choice`` are used).
        a: input tensor; ``a.shape`` must support ``.copy()``.
        axis: reduction axis (may be invalid in axis error cases).
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # The coin is always flipped before the length check.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - set([DType.INT32]))
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005184
5185 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts (as indexed below): IFM is NHWC, filter is OHWI and the output
    is NHWC with ``filter.shape[0]`` channels. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    Error cases handled here:
      * ``ConvOutputShapeMismatch`` - H, W or both are grown by a random
        multiple (1-3) of the corresponding stride.
      * ``WrongInputType`` - the output dtype is INT32 instead of
        ``accum_dtype``.
      * ``WrongOutputType`` - a random dtype from ``gtu.usableDTypes`` that
        excludes the types treated as valid for ``ifm.dtype``: FP16 and FP32
        for FP16 input, FP16 for FP8 input, otherwise the expected output
        dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``choice`` is used).
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: (stride_y, stride_x).
        padding: (top, bottom, left, right).
        dilations: (dilation_y, dilation_x).
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # These must be mutually exclusive: with a plain `if` for the FP8
        # check, FP16 input fell through to the final `else` and lost its
        # FP16/FP32 exclusion.
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005238
5239 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts (as indexed below): IFM is NDHWC, filter is ODHWI and the output
    is NDHWC with ``filter.shape[0]`` channels. For spatial axis ``k``
    (0=D, 1=H, 2=W) the output size is
    ``(ifm[k+1] - 1 + pad[2k] + pad[2k+1] - (filter[k+1] - 1) * dil[k]) // stride[k] + 1``.

    Error cases handled here:
      * ``ConvOutputShapeMismatch`` - a random choice of 1, 2 or 3 grows only
        D, H or W respectively, and 4 grows all three. Each growth is a
        random multiple (1-4) of that axis's stride.
      * ``WrongInputType`` - the output dtype is INT32 instead of
        ``accum_dtype``.
      * ``WrongOutputType`` - a random dtype from ``gtu.usableDTypes``
        excluding FP16 and FP32 for FP16 input, or the expected output dtype
        otherwise.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``choice`` is used).
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: per-axis strides (D, H, W).
        padding: (d_before, d_after, h_before, h_after, w_before, w_after).
        dilations: per-axis dilations (D, H, W).
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    spatial = []
    for axis in range(3):
        extent = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit.
        for axis in range(3):
            if change in (axis + 1, 4):
                spatial[axis] += rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With a wrong input type, fall back to a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5300
5301 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts (as indexed below): IFM is NHWC, filter is HWCM and the output is
    NHW with ``filter.shape[2] * filter.shape[3]`` channels. For spatial
    axis ``k`` (0=H, 1=W) the output size is
    ``(ifm[k+1] - 1 + pad[2k] + pad[2k+1] - (filter[k] - 1) * dil[k]) // stride[k] + 1``.

    Error cases handled here:
      * ``ConvOutputShapeMismatch`` - a random choice of 1 or 2 grows only H
        or W respectively, and 3 grows both. Each growth is a random
        multiple (1-3) of that axis's stride.
      * ``WrongInputType`` - the output dtype is INT32 instead of
        ``accum_dtype``.
      * ``WrongOutputType`` - a random dtype from ``gtu.usableDTypes``
        excluding FP16 and FP32 for FP16 input, or the expected output dtype
        otherwise.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator (its ``choice`` is used).
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: (stride_y, stride_x).
        padding: (top, bottom, left, right).
        dilations: (dilation_y, dilation_x).
        error_name: ErrorIf case being generated, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    # IFM: NHWC
    # Filter: HWCM
    # OFM: NHW C*M
    out_hw = []
    for axis in range(2):
        extent = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1
        out_hw.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow by whole strides so the non-integer-output error is not hit.
        for axis in range(2):
            if change in (axis + 1, 3):
                out_hw[axis] += rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], *out_hw, filter.shape[2] * filter.shape[3]]

    # With a wrong input type, fall back to a plausible output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005351
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator.

    The input ``ifm`` is NHWC. Spatial output sizes are
    ``(in + pad_a + pad_b - kernel) // stride + 1`` per axis, where
    ``pad[0:2]`` pad the height axis and ``pad[2:4]`` the width axis.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input tensor (NHWC).
        kernel: ``[kernel_h, kernel_w]``.
        stride: ``[stride_h, stride_w]``.
        pad: four padding values (height pair, then width pair).
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        # Invalid stride/padding: the test is invalid anyway, so just
        # produce a 1x1 spatial output.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> perturb height, 2 -> perturb width, 3 -> perturb both
        target = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer-output
        # ERROR_IF is not hit instead.
        if target in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if target in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    # Keep the set-difference form: its iteration order feeds rng.choice.
    wrong_dtypes = list(set(candidate_dtypes) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005391
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: ``input`` is ``[N, IC]``, ``filter`` is ``[OC, IC]`` and the
    output is ``[N, OC]``.

    ``rng`` and ``error_name`` are unused. The accumulator dtype is
    validated (and deliberately invalidated for ERROR_IF) in arg_gen, so
    it is used directly as the output dtype.
    """
    batch, out_channels = input.shape[0], filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005404
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: ``a`` is ``[N, H, C]``, ``b`` is ``[N, C, W]`` and the output
    is ``[N, H, W]``.

    Output dtype:
      * ``WrongOutputType``: a random dtype treated as wrong for ``a.dtype``.
      * ``WrongInputType``: INT32, a potentially correct output type.
      * otherwise: ``accum_dtype`` (validated in arg_gen).

    Raises:
        ValueError: for ``WrongOutputType`` when ``a.dtype`` has no entry
            in the wrong-output-type table. Previously this case surfaced
            as an UnboundLocalError.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple omits the output type(s) treated as valid for that
        # input dtype. Tuple order matters: rng.choice samples from it.
        fp8_wrong = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        float_wrong = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        wrong_by_input = {
            DType.INT8: (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ),
            DType.INT16: (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ),
            DType.FP8E4M3: fp8_wrong,
            DType.FP8E5M2: fp8_wrong,
            DType.FP32: float_wrong,
            DType.FP16: float_wrong,
            DType.BF16: float_wrong,
        }
        incorrect_types = wrong_by_input.get(a.dtype)
        if incorrect_types is None:
            raise ValueError(
                f"matmulOp: no WrongOutputType candidates for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005470
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT.

    The output shape is the first input's shape, with the ``axis`` sizes of
    all remaining inputs added on. When the ERROR_IF case makes that sum
    meaningless (rank mismatch or an invalid axis), the first input's shape
    is used unchanged.
    """
    first = inputs[0]
    others = inputs[1:]
    output_shape = first.shape.copy()

    sum_is_undefined = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not sum_is_undefined:
        for other in others:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Push the concatenated dimension away from the true sum
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005506
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Dimension ``i`` of the output is
    ``padding[i][0] + padding[i][1] + a.shape[i]``. ERROR_IF cases grow one
    random dimension, change the rank, or pick a wrong dtype.
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005537
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for DIM.

    The result is a ``[1]`` tensor of dtype ``DType.SHAPE``; ``a`` and
    ``axis`` do not affect it here. For ``WrongOutputType`` a random
    non-SHAPE dtype is used instead.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Routed through a set so rng.choice sees the same ordering as before.
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5558
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE.

    The output takes the requested ``shape``. For
    ``TensorSizeInputOutputMismatch``, every dimension is grown by a value
    from ``rng.integers(1, 10)``, so the element count no longer matches.
    """
    output_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(all_dtypes) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005583
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE.

    The output shape is ``size``; ``start`` is not used here. ERROR_IF
    cases perturb the dtype, the sizes, or the rank, or substitute the
    input shape.
    """
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dims <= 2 are only ever increased, so a positive dimension
            # never drops below 1.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005617
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE.

    Dimension ``i`` of the output is ``a.shape[i] * multiples[i]``.
    ERROR_IF cases change the rank or pick a wrong dtype.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005646
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for TRANSPOSE.

    Dimension ``i`` of the output is ``a.shape[perms[i]]``. ERROR_IF cases
    may keep the input shape, grow every dimension, change the rank, or
    pick a wrong dtype.
    """
    assert len(perms) == len(a.shape)

    if error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        # These ERROR_IF cases presumably corrupt perms, so it is not
        # applied; the input shape is kept.
        output_shape = a.shape.copy()
    else:
        output_shape = [a.shape[p] for p in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005679
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    ``values`` is rank 3 and ``indices`` is rank 2, and they share
    dimension 0. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The rank
    checks are skipped for ``WrongRank`` tests.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(all_dtypes) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005705
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    The output has ``values_in``'s shape (the same list object) and dtype.
    ``values_in`` and ``input`` are rank 3 and ``indices`` is rank 2; the
    N/W/C consistency checks are skipped for ``WrongRank`` tests.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    all_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(all_dtypes) - {values_in.dtype})
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005736
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The output keeps ``input``'s shape. Its dtype is INT32 for an INT16
    input and INT8 otherwise; ``WrongOutputType`` replaces it with a
    different random dtype.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([d for d in candidates if d != output_dtype])
    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005756
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE.

    ``scale`` is ``[y_n, y_d, x_n, x_d]`` and ``offset``/``border`` are
    ``[y, x]``. Per axis, the output size is
    ``((in - 1) * n - offset + border) // d + 1``. ``mode`` and
    ``input_dtype`` are not used here.

    ERROR_IF handling clamps sizes so the output tensor stays valid, then
    optionally shifts height/width, changes rank, batch or channels.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the floor divisions below cannot divide by zero
        y_n, y_d, x_n, x_d = max(y_n, 1), max(y_d, 1), max(x_n, 1), max(x_d, 1)

    out_h = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    out_w = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # Changed scale/offset/border can yield a degenerate size; keep the
        # output tensor itself valid.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, limit), min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def shift(size, step):
            # Move by one whole denominator step (avoids the non-integer
            # ERROR_IF); go down instead if going up would reach the limit.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        target = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        if target in (1, 3):
            out_h = shift(out_h, y_d)
        if target in (2, 3):
            out_w = shift(out_w, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses input.shape[0], not shape[3];
        # presumably because a wrong-rank input may lack a 4th dimension.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005835
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are unused.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005839
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    ``output_shape`` is supplied by the caller rather than computed here.
    The output dtype is ``accum_dtype`` (INT32 for ``WrongInputType``), or
    a random unusable dtype for ``WrongOutputType``.

    NOTE(review): for ``ConvOutputShapeMismatch`` the caller's
    ``output_shape`` is modified in place (dims 1 and 2), so a caller that
    reuses that list afterwards sees the perturbed values. Confirm whether
    callers rely on this before changing it.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        target = rng.choice(steps)  # 1: dim 1, 2: dim 2, 3: both
        if target in (1, 3):
            output_shape[1] += rng.choice(steps)
        if target in (2, 3):
            output_shape[2] += rng.choice(steps)

    # For WrongInputType, pick some potentially correct output dtype
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005865
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D.

    Both outputs take ``ifm1``'s shape and dtype (presumably the real and
    imaginary parts). ERROR_IF cases perturb the dtype, the batch
    dimension, or one of dimensions 1/2.

    Returns:
        A list with the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        spatial_dim = rng.choice([1, 2])
        output_shape[spatial_dim] += rng.integers(1, 10)

    # Both outputs share the same shape list object
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5896
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D.

    Each output has ``value``'s shape with the last dimension replaced by
    ``last // 2 + 1``, and ``value``'s dtype. ERROR_IF cases perturb the
    dtype, the batch dimension, or one of dimensions 1/2.

    Returns:
        A list with the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        spatial_dim = rng.choice([1, 2])
        output_shape[spatial_dim] += rng.integers(1, 10)

    # Both outputs share the same shape list object
    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005921
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for the shape-addition op.

    The output copies ``a``'s shape and has dtype ``DType.SHAPE``.
    ``DimensionMismatch`` bumps one random dimension by 1;
    ``WrongOutputType`` picks from ``gtu.get_wrong_output_type(a.dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # Drawn even when unused. NOTE(review): rng is presumably shared, so
    # moving this draw into the branch below would change later random
    # values; keep as-is unless regenerating tests.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(shape, out_dtype)