blob: 415858c8688c51b7094876853c4c23f350af4770 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header written at the top of every generated C++ meta data file; the
# copyright year is the year in which the generator is run.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Create a test generator configured by the parsed program ``args``.

    Seeds the random number generator from ``args.random_seed``, builds the
    op lists, creates the desc.json schema validator and the data generator
    library, and works out the random value range of every floating point
    data type from ``args.tensor_fp_value_range``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def limitToFP(value, maxFP):
        # "max"/"-max" select the type's high value from TVG_FLOAT_HIGH_VALUE;
        # any other non-zero value is trimmed to lie within [-maxFP, maxFP]
        if value == "max":
            return maxFP
        if value == "-max":
            return -maxFP
        if value < 0:
            return max(value, -maxFP)
        if value > 0:
            return min(value, maxFP)
        return value

    # Work out the (sorted) floating point range for each FP type
    self.random_float_range = {}
    for fp_dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        maxFP = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype]
        limited = [limitToFP(v, maxFP) for v in args.tensor_fp_value_range]
        self.random_float_range[fp_dtype] = tuple(sorted(limited))
84
def createSerializer(self, opName, testPath):
    """Make the test output directory and a new serializer that writes into it.

    ``self.testPath`` becomes ``opName/testPath`` (relative to basePath).
    The serializer's const mode is INPUTS when lazy data generation is on,
    EMBED_DUMP when dump_consts is set, and EMBED otherwise.
    """
    self.testPath = os.path.join(opName, testPath)

    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the serializer created by createSerializer (None before that)."""
    serializer = self.ser
    return serializer
101
def serialize(self, testName, metaData=None):
    """Write the current test graph and its descriptor to the test directory.

    Creates ``<testName>.tosa`` (flatbuffer binary) and ``desc.json``. When
    ``metaData`` is given, it is stored in desc.json under "meta" before
    schema validation. Its entries are also written out as C++ raw string
    literals:

    * "data_gen"   -> ``<testName>_meta_data_gen.cpp`` (lazy data gen only)
    * "compliance" -> ``<testName>_meta_compliance.cpp``
    """
    test_dir = Path(self.basePath) / self.testPath

    def writeMetaCpp(suffix, comment, var_name, data):
        # Embed the JSON meta data in a C++ `const char*` raw string literal
        with (test_dir / f"{testName}_{suffix}.cpp").open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(comment)
            fd.write(f'const char* {var_name} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # TOSA flatbuffer binary
    with (test_dir / f"{testName}.tosa").open("wb") as fd:
        fd.write(self.ser.serialize())

    # JSON descriptor from the serializer, plus any extra meta data
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            writeMetaCpp(
                "meta_data_gen",
                "// Test meta data for data generation setup\n\n",
                f"json_tdg_config_{test_dir.stem}",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            writeMetaCpp(
                "meta_compliance",
                "// Test meta data for compliance validation\n\n",
                f"json_tvf_config_{test_dir.stem}",
                metaData["compliance"],
            )

    with (test_dir / "desc.json").open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a fresh generator.

    Uses ``seed`` when given, otherwise ``self.random_seed + 1``.
    """
    self.rng = np.random.default_rng(self.random_seed + 1 if seed is None else seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) value range boundaries for ``dtype``.

    ``high`` is excluded from the range (low <= value < high) unless
    ``high_inclusive`` is True. In that case it is reduced by one so that
    low <= value <= high. Floating point types return the range set up in
    __init__ unchanged, regardless of ``high_inclusive``. SHAPE uses the
    first two entries of ``args.tensor_shape_range``.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        int_bounds = {
            DType.BOOL: (0, 2),
            DType.UINT8: (0, 256),
            DType.UINT16: (0, 65536),
            # TOSA specific INT4 weight range from -7 to 7
            DType.INT4: (-7, 8),
            DType.INT8: (-128, 128),
            DType.INT16: (-32768, 32768),
            DType.INT32: (-(1 << 31), (1 << 31)),
            DType.INT48: (-(1 << 47), (1 << 47)),
        }
        if dtype not in int_bounds:
            raise Exception("Unknown dtype: {}".format(dtype))
        bounds = int_bounds[dtype]

    if high_inclusive:
        # Inclusive range: low <= range <= high
        return (bounds[0], bounds[1] - 1)
    # Exclusive high: low <= range < high
    return bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a numpy array of random data with the given ``shape``.

    Values are drawn from ``data_range`` (low, high) when it is given,
    otherwise from ``getDTypeRange(dtype)``; integer draws exclude ``high``.
    Per type:

    * BOOL ignores the range.
    * Floating point data is drawn uniformly. FP16 is returned as float16
      and FP32 as float32. BF16/FP8 values are passed through the gtu
      conversion helpers and returned as float32.
    * INT8/UINT8/INT16/UINT16 use the matching numpy integer type, INT48
      and SHAPE use int64, and every other integer type uses int32.
    """
    if data_range is None:
        data_range = self.getDTypeRange(dtype)
    low, high = data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    float_types = (DType.FP16, DType.BF16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2)
    if dtype in float_types:
        samples = self.rng.uniform(low=low, high=high, size=shape)
        if dtype == DType.FP16:
            return np.float16(samples)
        samples = np.float32(samples)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(gtu.vect_f32_to_bf16(samples))
        if dtype == DType.FP8E4M3:
            return np.float32(gtu.vect_f32_to_fp8e4m3(samples))
        if dtype == DType.FP8E5M2:
            return np.float32(gtu.vect_f32_to_fp8e5m2(samples))
        return samples

    # Integer types: choose the numpy type the samples are cast to
    if dtype == DType.INT8:
        cast = np.int8
    elif dtype == DType.UINT8:
        cast = np.uint8
    elif dtype == DType.INT16:
        cast = np.int16
    elif dtype == DType.UINT16:
        cast = np.uint16
    elif dtype in (DType.INT48, DType.SHAPE):
        cast = np.int64
    else:
        # All other integer types
        cast = np.int32
    return cast(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700229
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Each placeholder gets random data from getRandTensor, or None when lazy
    data generation is enabled. Returns the list of values returned by
    ``ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
242
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Each constant gets random data from getRandTensor, or None when lazy
    data generation is enabled. Returns the list of values returned by
    ``ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
255
def makeShape(self, rank):
    """Return an int32 shape array with ``rank`` random dimensions.

    Dimensions are drawn from ``args.tensor_shape_range`` (high exclusive).
    A target shape set by setTargetShape is returned instead, whatever
    ``rank`` is.
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
        return np.int32(dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700266
def setTargetShape(self, shape):
    """Make makeShape return ``shape``; a falsy value restores random shapes."""
    self.targetted_shape = shape
269
def randInt(self, low=0, high=256):
    """Return one random np.int32 value with low <= value < high."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
272
def getRandNumberDType(self, dtype):
    """Return one random scalar for ``dtype`` within getDTypeRange(dtype).

    FP32/FP16 give np.float32/np.float16 values. BF16 and FP8 values are
    drawn as float32 and passed through the matching gtu conversion helper.
    BOOL picks False or True. INT48 and SHAPE give np.int64, and all other
    types give np.int32 (high exclusive).
    """
    low, high = self.getDTypeRange(dtype)
    gen = self.rng

    if dtype == DType.BOOL:
        return gen.choice([False, True])
    if dtype == DType.INT48 or dtype == DType.SHAPE:
        # Special size
        return np.int64(gen.integers(low, high, size=1))[0]
    if dtype == DType.FP32:
        return np.float32(gen.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(gen.uniform(low=low, high=high))
    if dtype in (DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        rand_f32 = np.float32(gen.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(rand_f32)
        if dtype == DType.FP8E4M3:
            return gtu.vect_f32_to_fp8e4m3(rand_f32)
        return gtu.vect_f32_to_fp8e5m2(rand_f32)

    return np.int32(gen.integers(low, high, size=1))[0]
296
def shapeStr(self, shape):
    """Return ``shape`` as its dimensions joined by "x".

    For example [1, 8, 8, 3] gives "1x8x8x3"; an empty shape gives "".
    """
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
def typeStr(self, dtype):
    """Return the short string name of ``dtype`` (from gtu.DTYPE_ATTRIBUTES).

    A list/tuple of at least two dtypes gives the names of the first two
    joined by "x". Every entry is still converted, but a third entry (the
    accumulator type) is left out of the result.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(t) for t in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
def constrictBatchSize(self, shape):
    """Cap ``shape[0]`` at ``args.max_batch_size`` in place and return ``shape``.

    The shape is left alone when max_batch_size is unset or explicit target
    shapes were given.
    """
    batch_limit = self.args.max_batch_size
    if batch_limit and not self.args.target_shapes:
        shape[0] = min(shape[0], batch_limit)
    return shape
325
def makeDimension(self):
    """Return one random dimension size drawn from ``args.tensor_shape_range``."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
330
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Return the compliance meta data dict for ``outputTensor``, or None.

    Returns None in three cases: ERROR_IF tests (``errorName`` set), output
    types not supported by compliance, and the dot product style ops listed
    below when their input type is not supported.

    Otherwise the dict holds the compliance "mode" name and the output
    "data_type", plus the extra "*_info" settings the mode needs, taken
    from ``argsDict`` or ``op["compliance"]``.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    dg_type = argsDict["dg_type"]
    op_compliance = op["compliance"] if "compliance" in op else {}

    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        s_value = argsDict["s"]
        ks_value = argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]
        compliance_tens["dot_product_info"] = {"s": s_value, "ks": int(ks_value)}
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in op_compliance:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op_compliance["ulp"]}
    elif "relative" in op_compliance:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op_compliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op_compliance["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_normal_divisor" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op_compliance["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
399
400 # Build Op functions
401 # Create the output tensor (calling OutputShaper as needed)
402 # Do final tweaks to attributes (if necessary for errorIf)
403 # Add Op into graph
404 # Return resulting tensor information or BuildInfo
405
class BuildInfo:
    """Result tensor(s) of a build function together with their compliance dicts."""

    def __init__(self, resultTensor, complianceDict):
        # A single tensor (with one compliance dict or None) is stored as
        # one-element lists so that both forms are handled the same way
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return {"version": "0.1", "tensors": {...}}, or None if there is nothing.

        "tensors" maps each result tensor's name to its compliance dict.
        Tensors whose dict is None are left out.
        """
        if self.complianceDictList is None:
            return None
        tensors = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        return {"version": "0.1", "tensors": tensors} if tensors else None
Eric Kunzee5e26762020-10-13 16:11:07 -0700439
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a unary op to the serializer graph and return its BuildInfo.

    ``inputs`` must hold exactly one tensor. Returns None when the ERROR_IF
    validators reject the test. For NEGATE, ``qinfo`` supplies the two zero
    points of the NegateAttribute. In a WrongOutputType test whose output
    type is not INT8/UINT8, they are regenerated with
    TosaQuantGen.getZeroPoint for the input and output types.
    """
    assert len(inputs) == 1
    src = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, src.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700492
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a binary op whose output comes from OutputShaper.binaryBroadcastOp.

    ``inputs`` must hold exactly two tensors. Returns the op's BuildInfo,
    or None when the ERROR_IF validators reject the test. ``qinfo`` is
    accepted but not used.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700534
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a binary op shaped by OutputShaper.binaryNonBroadcastOp.

    Unlike the other build functions, this one takes the two input tensors
    directly and returns the result tensor itself rather than a BuildInfo.
    ``validator_fcns`` and ``error_name`` are accepted but not used.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [out_tensor.name])
    return out_tensor
539
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an arithmetic right shift op to the graph and return its BuildInfo.

    ``inputs`` must hold exactly two tensors. ``args_dict["round"]`` is
    passed to the op's ArithmeticRightShiftAttribute. Returns None when
    the ERROR_IF validators reject the test. ``qinfo`` is accepted but not
    used.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Named round_attr (not `round`) so the builtin round() is not shadowed
    round_attr = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_attr)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800585
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a MUL op to the serializer graph and return its BuildInfo.

    ``inputs`` holds the two operands followed by the shift tensor, and all
    three become inputs of the op. The result type is forced to INT32 when
    the first operand is not FP16/BF16/FP32, and to a random
    INT8/INT16/INT48 type in a WrongOutputType test. Returns None when the
    ERROR_IF validators reject the test.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    lhs, rhs, shift = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tensor.setDtype(wrong_dtype)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name, shift.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700638
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a table lookup op to the graph and return its BuildInfo.

    ``inputs`` must hold exactly one tensor, and ``args_dict["table"]`` is
    passed to the op's TableAttribute. Returns None when the ERROR_IF
    validators reject the test. ``qinfo`` is accepted but not used.
    """
    assert len(inputs) == 1
    src = inputs[0]
    table_values = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, self.rng, src, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table_values)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700681
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (cond, a, b).
        args_dict: test arguments, forwarded to compliance generation.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    cond, x, y = inputs

    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, x, y, error_name)

    placeholder_count, const_count = op["operands"]
    operand_count = placeholder_count + const_count

    # The name lists may be altered here when an ERROR_IF test is requested.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, x.name, y.name], [out_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=x,
        input3=y,
        input_shape=x.shape,
        input_dtype=x.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_count,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, x.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700729
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: the two operand tensors.
        args_dict: test arguments, forwarded to compliance generation.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    placeholder_count, const_count = op["operands"]
    operand_count = placeholder_count + const_count

    # The name lists may be altered here when an ERROR_IF test is requested.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_count,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700777
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; "axis" is passed to OutputShaper.argmaxOp
            and serialized as the AxisAttribute.
        validator_fcns: ERROR_IF validator functions to run. Defaults to None,
            consistent with the other build_* methods.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700821
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; "stride", "pad" and "kernel" are required.
            "acc_type" is optional (max_pool supplies none) and defaults to
            DType.UNKNOWN.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0].

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    # Named ifm rather than `input` to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100896
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (ifm, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (4 values) and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named weight rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): qinfo[1] is passed to ConvAttribute right after the input
    # zero point (presumably the weight zero point), yet here it is derived
    # from the output dtype - confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700978
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (ifm, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (6 values) and
            "dilation" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named weight rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): qinfo[1] is passed to ConvAttribute right after the input
    # zero point (presumably the weight zero point), yet here it is derived
    # from the output dtype - confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001060
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (ifm, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" (4 values) and
            "out_shape" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named weight rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): qinfo[1] is passed to TransposeConvAttribute right after
    # the input zero point (presumably the weight zero point), yet here it is
    # derived from the output dtype - confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001135
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (ifm, weight, bias).
        args_dict: test arguments; "acc_type", "stride", "pad" and "dilation"
            are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named weight rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): qinfo[1] is passed to ConvAttribute right after the input
    # zero point (presumably the weight zero point), yet here it is derived
    # from the output dtype - confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001216
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: three tensors, in the order (ifm, weight, bias).
        args_dict: test arguments; "acc_type" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # Named weight rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001272
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: the two operand tensors (a, b).
        args_dict: test arguments; "acc_type" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: two-element zero-point list, or None. None is serialized as
            [0, 0] instead of failing when indexed.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # With no zero points supplied, serialize zeros rather than raising
    # TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001322
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test.

    Args:
        op: operator definition dict; its "op" and "operands" entries are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; "axis" is passed to OutputShaper.reduceOp and
            serialized as the AxisAttribute. For a valid REDUCE_PRODUCT test
            this dict is updated in place with "n" (see below).
        validator_fcns: ERROR_IF validator functions to run. Defaults to None,
            consistent with the other build_* methods.
        error_name: ERROR_IF under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance. Note this writes into
        # the caller's args_dict.
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001371
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator whose bounds are two random values of the input dtype.

    Normally the smaller random value becomes ``min_val`` and the larger
    ``max_val``. For the ``ErrorIf.MaxSmallerMin`` case the pair is redrawn
    until the values differ, then assigned the other way round (max < min).

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed on to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    def draw_pair():
        # Two random values of the input dtype, drawn left to right
        return [
            self.getRandNumberDType(in_tensor.dtype),
            self.getRandNumberDType(in_tensor.dtype),
        ]

    bounds = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Equal values cannot give max < min, so redraw until they differ
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    if in_tensor.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if in_tensor.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Floating-point bounds occupy the last two ClampAttribute slots
        clamp_args = (0, 0, min_val, max_val)
    elif in_tensor.dtype in (DType.INT8, DType.INT16):
        # Integer bounds occupy the first two ClampAttribute slots
        clamp_args = (min_val, max_val, 0, 0)
    else:
        # to avoid internal error for incorrect input types
        clamp_args = (0, 0, 0, 0)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, *clamp_args)
    self.ser.addOperator(op["op"], input_names, output_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001440
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Emit ``op`` on tensor ``a`` with a LeakyReluAttribute.

    Older builder signature: it takes the input tensor directly (not an
    ``inputs`` list), ignores ``validator_fcns``, performs no ERROR_IF
    validation and returns the result tensor itself rather than a BuildInfo.
    The attribute value is a random FP32 number.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    relu_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1449
# Needs an additional type/input: only the single tensor ``a`` is wired up.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Emit ``op`` with ``a`` as its only input and no attribute.

    Older builder signature: ``validator_fcns`` is ignored, no ERROR_IF
    validation is performed, and the result tensor itself (not a BuildInfo)
    is returned.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1456
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input operator that carries no attribute.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, passed on to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001497
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator over every tensor in ``inputs``.

    CONCAT reads ``args_dict["axis"]`` and serializes it as an AxisAttribute.
    CONCAT_SHAPE always concatenates along axis 0 and is emitted without an
    attribute.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: tensors to concatenate; the ERROR_IF shape/dtype checks and
            the compliance metadata use the first one.
        args_dict: test arguments; must contain "axis" for CONCAT.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance rather than an exact type() comparison (pycodestyle E721)
        assert isinstance(axis, int), f"concat axis must be an int, got {axis!r}"

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001556
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator whose padding is supplied by the second input tensor.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly two tensors: the data input and the padding input.
        args_dict: test arguments; must contain "pad" (used for shape
            inference and ERROR_IF checks), "pad_const_int" and "pad_const_fp".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: forwarded to the ERROR_IF validation.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 2
    in_tensor, padding_tensor = inputs
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(
        self.ser, self.rng, in_tensor, padding, error_name
    )

    # Leave the attribute's padding list empty so the padding values come
    # from the second input tensor instead.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name, padding_tensor.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": in_tensor.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": in_tensor.dtype,
        "output_dtype": out_tensor.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [out_tensor],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholders + consts,
        "input1": in_tensor,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_names, output_names, pad_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, in_tensor.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001614
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying the size of ``args_dict["axis"]``.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs
    dim_axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, in_tensor, dim_axis, error_name)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)

    # No compliance metadata is produced for DIM
    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001660
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator whose target shape is the second input tensor.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly two tensors: the data input and the shape input.
        args_dict: test arguments; "new_shape" holds the target shape used
            to compute the output tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 2
    in_tensor, shape_tensor = inputs
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, in_tensor, new_shape, error_name
    )

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name, shape_tensor.name], [out_tensor.name]
    )

    validation_kwargs = dict(
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    # The shape comes from the second input, so no attribute is attached
    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001704
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along ``args_dict["axis"]``.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; must contain "axis".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs
    reverse_axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, in_tensor, error_name)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "axis": reverse_axis,
        "input_shape": in_tensor.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": in_tensor.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholders + consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reverse_axis)
    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)

    # No compliance metadata is attached here
    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001744
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; must contain "perms".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 1
    (in_tensor,) = inputs
    permutation = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, in_tensor, permutation, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(permutation)

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        perms=permutation,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
        input1=in_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001793
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator whose start and size are supplied as input tensors.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly three tensors: data, start and size.
        args_dict: test arguments; "start" and "size" hold the values used to
            compute the output tensor and for the ERROR_IF checks.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 3
    in_tensor, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, in_tensor, start_values, size_values, error_name
    )

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [in_tensor.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    validation_kwargs = dict(
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
        input1=in_tensor,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    # Start/size are operator inputs, so no attribute is attached
    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001841
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator whose multiples are supplied by the second input.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly two tensors: the data input and the multiples input.
        args_dict: test arguments; "multiples" holds the values used to
            compute the output tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 2
    in_tensor, multiples_tensor = inputs
    multiples = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, in_tensor, multiples, error_name
    )

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_tensor.name, multiples_tensor.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": in_tensor.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": in_tensor.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholders + consts,
        "input1": in_tensor,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # Multiples are an operator input, so no attribute is attached
    self.ser.addOperator(op["op"], input_names, output_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, in_tensor.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001886
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly two tensors, in order: values, indices.
        args_dict: test arguments, passed on to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata (based on the values dtype), or None if
        ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, indices_tensor, error_name
    )

    placeholders, consts = op["operands"]
    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_tensor.name, indices_tensor.name],
        [out_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001929
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator.

    Args:
        op: operator description dict (``op["op"]`` opcode, ``op["operands"]``
            placeholder/const counts).
        inputs: exactly three tensors, in order: values_in, indices, input.
        args_dict: test arguments, passed on to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata (based on the values_in dtype), or None if
        ``evValidateErrorIfs`` reports a failure.
    """
    assert len(inputs) == 3
    # Named input_tensor (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    # ERROR_IF shape/dtype checks are made against values_in
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001971
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    The scale/offset/border values reach the operator as the extra input
    tensors inputs[1:4]. The ResizeAttribute is written with empty
    scale/offset/border lists and only carries the mode. The values in
    args_dict are still used to create the result tensor and for the
    ERROR_IF checks.

    Args:
        op: operator description dict; op["op"] is the serialized operator
            and op["operands"] holds the two operand counts that are summed
            into num_operands for the ERROR_IF checks.
        inputs: the four tensors [input, scale, offset, border].
        args_dict: must provide "mode", "scale", "offset", "border" and
            "output_dtype".
        validator_fcns: ERROR_IF validator functions to run. Defaults to
            None, like the other build_* methods.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when the
        ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 4
    input_tens, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tens,
        mode,
        scale,
        offset,
        border,
        input_tens.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tens.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tens.dtype,
        output_dtype=output_dtype,
        input_shape=input_tens.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # Scale/offset/border are supplied as input tensors, so write them as
    # empty lists in the ResizeAttribute
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tens.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002050
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    Only the first output tensor is returned.

    NOTE(review): `op` is passed straight to addOperator, while the sibling
    build_* methods pass op["op"]. validator_fcns is not used here. Confirm
    how callers supply `op` before relying on this method.
    """
    first_out = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    second_out = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    in_names = [val.name, val2.name]
    out_names = [first_out.name, second_out.name]
    self.ser.addOperator(op, in_names, out_names)
    return first_out
2058
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test by registering the single input tensor as an output.

    Returns a TosaTestGen.BuildInfo pairing that tensor with its compliance
    metadata. validator_fcns and qinfo are unused.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]
    self.ser.addOutputTensor(const_tens)
    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002071
2072 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST test that converts inputs[0] to args_dict["out_type"].

    Returns:
        TosaTestGen.BuildInfo for the output tensor, or None when the
        ERROR_IF validators reject the generated test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dst = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, args_dict["out_type"], error_name
    )

    # Invalidate the input/output name lists for ERROR_IF tests
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [dst.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=dst.shape,
        input_dtype=src.dtype,
        output_dtype=dst.dtype,
        result_tensors=[dst],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return TosaTestGen.BuildInfo(
        dst,
        self.tensorComplianceMetaData(op, src.dtype, args_dict, dst, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002116
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Input and output zero points are chosen here. They are random for
    INT8/UINT8, taken from {0, 32768} for UINT16, and non-zero for the
    zero-point ERROR_IF tests. For valid scale32 tests, the input data file
    is clamped so that (value - input_zp) meets the apply_scale_32 range
    requirement for each channel's shift.

    Args:
        op: operator description dict; op["op"] is the serialized operator
            and op["operands"] holds the two operand counts that are summed
            into num_operands for the ERROR_IF checks.
        inputs: the three tensors [value, multiplier, shift].
        args_dict: must provide "output_dtype", "scale", "double_round",
            "per_channel" and "shift" (one shift value per channel).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: ignored; the (input_zp, output_zp) pair is generated here.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of channels the per-channel shift values apply to
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    # NOTE(review): in/out_type_width are updated below but not read later
    # in this method.
    in_type_width = gtu.dtypeWidth(val.dtype)
    out_type_width = gtu.dtypeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the data file's dtype. The
        # comparison must be element-wise; ndarray.all() would collapse the
        # array to a single bool before comparing.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        ), "rescale input values out of range of their data type"

        # Cast back to the data file's dtype (range verified above)
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # Multiplier/shift are supplied as input tensors, so the attribute
    # lists are left empty
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        [],
        [],
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002281
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding `cond`. For
    CondIfCondNotMatchingBool, the dtype comes from
    gtu.get_wrong_output_type. For CondIfCondShapeNotSizeOne, the shape is
    randomly [2] or [1, 2].
    """
    wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if wrong_type
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2298
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output an INT32 constant.

    The supplied then/else tensors are ignored except for the shape of the
    first one, which becomes the output shape. Each block is filled with a
    single const node of random values in [0, 256).

    Args:
        op: operator description dict; op["op"] is the serialized operator.
        inputs: two tensors (then, else); only inputs[0].shape is used.
        args_dict: test arguments; args_dict["condition"] is the value
            stored in the condition tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Make then/else tensors
    out_shape = then_tens.shape

    # The block outputs are INT32 regardless of the input dtypes
    dtype = DType.INT32

    # Create an incorrect output shape for error_if tests.
    # Every dimension changes by a non-zero amount; dims <= 3 only grow, so
    # a positive dimension stays positive.
    # incorrect_shape/incorrect_arr are only defined here, and only these
    # two error cases read them below.
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    self.ser.addBasicBlock(then_block)
    # Build the actual then/else tensors inside their blocks
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
    else:
        then_tens = self.ser.addConst(out_shape, dtype, then_arr)
    self.ser.addOutputTensor(then_tens)

    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
    else:
        else_tens = self.ser.addConst(out_shape, dtype, else_arr)
    self.ser.addOutputTensor(else_tens)

    # Validate against every basic block built so far in the current region
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002383
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each apply a binary op to a and b.

    For FP32/BF16/FP16/INT32 the then block is ADD and the else block is SUB.
    For INT8/INT16 they are LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT.
    args_dict["condition"] selects which block's operation the compliance
    metadata is generated for.

    Args:
        op: operator description dict; op["op"] is the serialized operator.
        inputs: the two operand tensors [a, b].
        args_dict: test arguments; must provide "condition".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Change every dim by a non-zero amount without making it zero
            # or negative: small dims only grow (same scheme as
            # build_cond_if_const).
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # Block input does not match the COND_IF input list
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # Block output does not match the COND_IF output list
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2487
def build_while_loop(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a WHILE_LOOP test that repeatedly adds `a` into an accumulator.

    The loop state is (iter, a, acc). iter starts at args_dict["iterations"]
    and acc starts at zeros. The COND block computes iter > 0. The BODY
    block computes acc + a and iter - 1, and passes `a` through unchanged.
    The accumulator output acc_out is the tested result.

    Args:
        op: operator description dict; op["op"] is the serialized operator.
        inputs: a single tensor [a].
        args_dict: test arguments; must provide "iterations".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo for acc_out, or None when the ERROR_IF
        validators reject the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    # Rank-0 INT32 loop counter.
    # NOTE(review): the name `iter` shadows the builtin.
    iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    # acc = self.ser.addOutput(a.shape, a.dtype)
    # NOTE(review): the initial values are built as np.int32 while the
    # placeholder uses a.dtype. Confirm this is intended for non-INT32 `a`.
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Give the accumulator output a shape that differs from its input
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    # Mismatched iter/acc tensors used by the graph input/output ERROR_IF tests.
    # NOTE(review): these offsets can make a dimension zero or negative
    # (e.g. rank-0 iter gets a single dim of -3 or -2). Confirm this is intended.
    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    # Valid tests use a rank-0 BOOL condition; the error tests corrupt its
    # type or shape
    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, acc + a)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    # Validate against every basic block built so far in the current region
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002618
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from a pair of input tensors.

    The output tensors are created by OutputShaper.fft2dOp (presumably the
    real and imaginary result parts - TODO confirm).  The "inverse" flag is
    read from args_dict and passed both to ERROR_IF validation and to the
    FFT attribute.

    Returns:
        TosaTestGen.BuildInfo holding the output tensors and one compliance
        entry per output, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_a, in_b = inputs
    is_inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, in_a, in_b, error_name)

    placeholder_count, const_count = op["operands"]
    out_shapes = [res.shape for res in results]
    out_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the name lists may be deliberately invalidated
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [in_a.name, in_b.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=is_inverse,
        input1=in_a,
        input2=in_b,
        input_shape=in_a.shape,
        input_dtype=in_a.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(is_inverse, local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, in_a.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002682
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    The output tensors are created by OutputShaper.rfft2dOp.

    Returns:
        TosaTestGen.BuildInfo holding the output tensors and one compliance
        entry per output, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]

    results = OutputShaper.rfft2dOp(self.ser, self.rng, source, error_name)

    placeholder_count, const_count = op["operands"]
    out_shapes = [res.shape for res in results]
    out_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the name lists may be deliberately invalidated
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, source.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002739
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test.

    The output tensor is created by OutputShaper.addShapeOp.

    Returns:
        TosaTestGen.BuildInfo with the output tensor and its compliance
        meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2785
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests for op.

    Args:
        op: TOSA_OP_LIST entry; "rank" (min, max) and "types" are read.
        shapeFilter: list of requested shapes, or empty/None for no filter.
        rankFilter: list of requested ranks, or None.
        dtypeFilter: list of requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator (negative tests only); called with
            check=False to obtain its "param_reqs".

    Returns:
        dict with keys "shapeFilter", "rankFilter" and "dtypeFilter", or None
        when testType is "negative" without a validator, or is neither
        "positive" nor "negative".
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range AND by operator,
        # using whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Only keep default ranks the operator actually allows; previously
            # range(1, 5) was used as-is, which could include ranks below rmin.
            lo = max(rmin, default_test_rank_range.start)
            hi = min(rmax, default_test_rank_range.stop - 1)
            if lo <= hi:
                cleanRankFilter = range(lo, hi + 1)
            else:
                # No overlap with the default range: take the operator's
                # lowest ranks, limited to the default range's size.
                cleanRankFilter = range(rmin, rmin + len(default_test_rank_range))
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes; list
        # entries (e.g. conv type triples) match on their first element
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the error, otherwise use the clean filters
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2866
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for one operator.

    For positive tests the filters are expanded once; for negative tests
    they are expanded once per ERROR_IF validator of the operator.  Each
    (rank, dtype, shape) combination is passed through the op's tensor
    generator and (optional) argument generator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes; None (the default) means no
            shape filter.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" or "negative" (ERROR_IF) tests.

    Returns:
        List of tuples (opName, testStr, dtype, error_name, shapeList, args).

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Avoid a shared mutable default argument; [None] means "no shape filter"
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list consists of (string representation,
                    # build function arguments) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "negative":
                            prefix = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        else:
                            prefix = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        testStr = "{}_{}".format(prefix, argStr) if argStr else prefix

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]

        def _is_invalid(test):
            # True if any invalid-test validator flags this test tuple
            return any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )

        testList = [test for test in testList if not _is_invalid(test)]

    return testList
2973
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Create one test: generate its tensors, build the operator, serialize it.

    Steps:
    1. Look up opName in TOSA_OP_LIST and create a serializer for testStr.
    2. Produce input tensors with the op's tensor-values generator.
    3. Call the op's build function.
    4. If the build result is falsy, report an invalid ERROR_IF test and do
       not serialize it.
    5. Otherwise serialize the test together with any data-generation and
       compliance meta data.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    # CONCAT ops take one input per shape rather than a fixed operand count
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: attach any compliance meta data, then serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003064
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST per kernel size.

    For every kernel size, a shallow copy of the matching template is stored
    under a name such as "conv2d_3x3", with "filter" set to the kernel and
    "template" set to False.  Afterwards every entry whose "template" value
    is truthy is removed.  The method does nothing if "conv2d_TEMPLATE" is
    already gone, which happens when the class is initialised more than once.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        return

    # Dynamically create op lists for convolutions with a list of kernel sizes
    if self.args.level8k:
        big_kernel = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_kernel], [big_kernel, 2]]
        kernels_3d = [[1, big_kernel, 1], [2, 2, big_kernel]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(prefix, kernel):
        # Concrete op named e.g. "<prefix>_3x1" from "<prefix>_TEMPLATE"
        variant = op_list[prefix + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))] = variant

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(prefix, kernel)

    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect template names first; deleting keys while iterating the dict
    # itself is not allowed
    template_names = [name for name, entry in op_list.items() if entry.get("template")]
    for name in template_names:
        del op_list[name]
3120
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in default fields.

    Each entry must have:
    - a two-element "operands" value;
    - a four-element "build_fcn" value;
    - a "types" field;
    - an "op" field.
    Otherwise an Exception naming the offending op is raised.  Entries with
    no "rank" are given DEFAULT_RANK_RANGE.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tensor_gen, _values_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(f"Op {name} is missing a valid type list in TOSA_OP_LIST")

        if "op" not in entry:
            raise Exception(f"Op {name} is missing the Op field in TOSA_OP_LIST")

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003162
# Tensor operator list (TOSA_OP_LIST), keyed by operator test name.
# Required fields:
# 'op': the TOSA Op enum value
# 'operands': tuple of (placeholder, const) operand counts
# 'build_fcn': tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional fields:
# 'rank': restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults() fills in DEFAULT_RANK_RANGE
# Other optional keys read by the generator: 'qgen',
# 'invalid_test_validators', 'error_if_validators', 'data_gen' and
# 'template' (template entries are expanded by createDynamicOpLists()).

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
# Integer and floating-point types
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# Floating-point, integer (excluding INT4) and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
# FP32 and INT16 only
TYPE_FI16 = [DType.FP32, DType.INT16]

# Integer types narrower than 32 bits, plus floating-point types
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range given by initOpListDefaults() to ops that do not specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003222
3223 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003225 "argmax": {
3226 "op": Op.ARGMAX,
3227 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003228 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003229 "build_fcn": (
3230 build_argmax,
3231 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003232 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233 TosaArgGen.agAxis,
3234 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003235 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evAxisSmallerZero,
3238 TosaErrorValidator.evAxisLargerRank,
3239 TosaErrorValidator.evArgmaxOutputRankMismatch,
3240 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3241 TosaErrorValidator.evWrongRank,
3242 TosaErrorValidator.evWrongInputType,
3243 TosaErrorValidator.evWrongOutputType,
3244 TosaErrorValidator.evWrongInputList,
3245 TosaErrorValidator.evWrongOutputList,
3246 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003247 "data_gen": {
3248 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3249 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 "avg_pool2d": {
3252 "op": Op.AVG_POOL2D,
3253 "operands": (1, 0),
3254 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 "build_fcn": (
3256 build_pool2d,
3257 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003258 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003259 TosaArgGen.agPooling,
3260 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003262 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003263 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003264 "error_if_validators": (
3265 TosaErrorValidator.evKernelSmallerOne,
3266 TosaErrorValidator.evStrideSmallerOne,
3267 TosaErrorValidator.evPadSmallerZero,
3268 TosaErrorValidator.evWrongRank,
3269 TosaErrorValidator.evWrongInputType,
3270 TosaErrorValidator.evWrongOutputType,
3271 TosaErrorValidator.evWrongInputList,
3272 TosaErrorValidator.evWrongOutputList,
3273 TosaErrorValidator.evInputZeroPointNotZero,
3274 TosaErrorValidator.evOutputZeroPointNotZero,
3275 TosaErrorValidator.evPadLargerEqualKernel,
3276 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003277 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003278 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003279 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003280 "data_gen": {
3281 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003284 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003285 "conv2d_TEMPLATE": {
3286 "op": Op.CONV2D,
3287 "operands": (1, 2),
3288 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 "build_fcn": (
3290 build_conv2d,
3291 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003292 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 TosaArgGen.agConv,
3294 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003295 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003296 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003297 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3298 "error_if_validators": (
3299 TosaErrorValidator.evWrongInputType,
3300 TosaErrorValidator.evWrongOutputType,
3301 TosaErrorValidator.evWrongInputList,
3302 TosaErrorValidator.evWrongOutputList,
3303 TosaErrorValidator.evInputZeroPointNotZero,
3304 TosaErrorValidator.evWeightZeroPointNotZero,
3305 TosaErrorValidator.evPadSmallerZero,
3306 TosaErrorValidator.evStrideSmallerOne,
3307 TosaErrorValidator.evDilationSmallerOne,
3308 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003309 TosaErrorValidator.evConvOutputShapeMismatch,
3310 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003311 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003312 "data_gen": {
3313 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3314 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003315 "template": True,
3316 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003317 # Templated operator. Filled in by createDynamicOpLists
3318 "conv3d_TEMPLATE": {
3319 "op": Op.CONV3D,
3320 "operands": (1, 2),
3321 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003322 "build_fcn": (
3323 build_conv3d,
3324 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003325 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003326 TosaArgGen.agConv,
3327 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003328 "qgen": TosaQuantGen.qgConv,
3329 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003330 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3331 "error_if_validators": (
3332 TosaErrorValidator.evWrongInputType,
3333 TosaErrorValidator.evWrongOutputType,
3334 TosaErrorValidator.evWrongInputList,
3335 TosaErrorValidator.evWrongOutputList,
3336 TosaErrorValidator.evInputZeroPointNotZero,
3337 TosaErrorValidator.evWeightZeroPointNotZero,
3338 TosaErrorValidator.evPadSmallerZero,
3339 TosaErrorValidator.evStrideSmallerOne,
3340 TosaErrorValidator.evDilationSmallerOne,
3341 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003342 TosaErrorValidator.evConvOutputShapeMismatch,
3343 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003344 ),
evacha0147ab1762024-01-29 13:23:23 +00003345 "data_gen": {
3346 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3347 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003348 "template": True,
3349 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003350 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003351 "depthwise_conv2d_TEMPLATE": {
3352 "op": Op.DEPTHWISE_CONV2D,
3353 "operands": (1, 2),
3354 "filter": [1, 1],
3355 "rank": (4, 4),
3356 "build_fcn": (
3357 build_depthwise_conv2d,
3358 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003359 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003360 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003361 ),
3362 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003363 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003364 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3365 "error_if_validators": (
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongInputList,
3369 TosaErrorValidator.evWrongOutputList,
3370 TosaErrorValidator.evInputZeroPointNotZero,
3371 TosaErrorValidator.evWeightZeroPointNotZero,
3372 TosaErrorValidator.evPadSmallerZero,
3373 TosaErrorValidator.evStrideSmallerOne,
3374 TosaErrorValidator.evDilationSmallerOne,
3375 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003376 TosaErrorValidator.evConvOutputShapeMismatch,
3377 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003378 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003379 "data_gen": {
3380 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3381 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003382 "template": True,
3383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 "fully_connected": {
3385 "op": Op.FULLY_CONNECTED,
3386 "operands": (1, 2),
3387 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 "build_fcn": (
3389 build_fully_connected,
3390 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003391 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003392 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003393 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003395 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003396 "error_if_validators": (
3397 TosaErrorValidator.evInputZeroPointNotZero,
3398 TosaErrorValidator.evWeightZeroPointNotZero,
3399 TosaErrorValidator.evWrongRank,
3400 TosaErrorValidator.evWrongInputType,
3401 TosaErrorValidator.evWrongOutputType,
3402 TosaErrorValidator.evWrongInputList,
3403 TosaErrorValidator.evWrongOutputList,
3404 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003405 "data_gen": {
3406 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3407 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "matmul": {
3410 "op": Op.MATMUL,
3411 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003412 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003413 "build_fcn": (
3414 build_matmul,
3415 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003416 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003417 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003420 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003421 "error_if_validators": (
3422 TosaErrorValidator.evInputZeroPointNotZero,
3423 TosaErrorValidator.evWrongRank,
3424 TosaErrorValidator.evWrongInputType,
3425 TosaErrorValidator.evWrongOutputType,
3426 TosaErrorValidator.evWrongInputList,
3427 TosaErrorValidator.evWrongOutputList,
3428 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003429 "data_gen": {
3430 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003431 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 "max_pool2d": {
3434 "op": Op.MAX_POOL2D,
3435 "operands": (1, 0),
3436 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003438 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003440 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 TosaArgGen.agPooling,
3442 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003443 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003444 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003445 "error_if_validators": (
3446 TosaErrorValidator.evKernelSmallerOne,
3447 TosaErrorValidator.evStrideSmallerOne,
3448 TosaErrorValidator.evPadSmallerZero,
3449 TosaErrorValidator.evWrongRank,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 TosaErrorValidator.evPadLargerEqualKernel,
3455 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003456 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003458 "data_gen": {
3459 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3460 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003461 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003462 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003463 "transpose_conv2d_TEMPLATE": {
3464 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003465 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 "rank": (4, 4),
3467 "build_fcn": (
3468 build_transpose_conv2d,
3469 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003470 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003471 TosaArgGen.agTransposeConv2D,
3472 ),
3473 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003474 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003475 "invalid_test_validators": (
3476 TosaInvalidValidator.ivHeightWidthInvalid,
3477 TosaInvalidValidator.ivNonPositiveOutputShape,
3478 ),
3479 "error_if_validators": (
3480 TosaErrorValidator.evWrongInputType,
3481 TosaErrorValidator.evWrongOutputType,
3482 TosaErrorValidator.evWrongInputList,
3483 TosaErrorValidator.evWrongOutputList,
3484 TosaErrorValidator.evInputZeroPointNotZero,
3485 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003486 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003487 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003488 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003489 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003490 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003491 "data_gen": {
3492 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3493 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003494 "template": True,
3495 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003496 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003497 "clamp": {
3498 "op": Op.CLAMP,
3499 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500 "build_fcn": (
3501 build_clamp,
3502 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003503 TosaTensorValuesGen.tvgLazyGenDefault,
3504 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003505 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003506 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003507 "error_if_validators": (
3508 TosaErrorValidator.evMaxSmallerMin,
3509 TosaErrorValidator.evWrongInputType,
3510 TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongInputList,
3512 TosaErrorValidator.evWrongOutputList,
3513 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003514 "data_gen": {
3515 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3516 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003517 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003518 "sigmoid": {
3519 "op": Op.SIGMOID,
3520 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003522 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003524 TosaTensorValuesGen.tvgLazyGenDefault,
3525 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003526 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003527 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003528 "error_if_validators": (
3529 TosaErrorValidator.evWrongInputType,
3530 TosaErrorValidator.evWrongOutputType,
3531 TosaErrorValidator.evWrongInputList,
3532 TosaErrorValidator.evWrongOutputList,
3533 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003534 "data_gen": {
3535 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3536 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003537 },
3538 "tanh": {
3539 "op": Op.TANH,
3540 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003542 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003544 TosaTensorValuesGen.tvgLazyGenDefault,
3545 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003547 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003548 "error_if_validators": (
3549 TosaErrorValidator.evWrongInputType,
3550 TosaErrorValidator.evWrongOutputType,
3551 TosaErrorValidator.evWrongInputList,
3552 TosaErrorValidator.evWrongOutputList,
3553 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003554 "data_gen": {
3555 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3556 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003557 "compliance": {
3558 "abs_error_lower_bound": 0.5,
3559 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003560 },
Won Jeon78155c62023-06-10 00:20:04 +00003561 "erf": {
3562 "op": Op.ERF,
3563 "operands": (1, 0),
3564 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003565 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003566 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003567 TosaTensorValuesGen.tvgLazyGenDefault,
3568 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003569 ),
3570 "types": TYPE_FP,
3571 "error_if_validators": (
3572 TosaErrorValidator.evWrongInputType,
3573 TosaErrorValidator.evWrongOutputType,
3574 TosaErrorValidator.evWrongInputList,
3575 TosaErrorValidator.evWrongOutputList,
3576 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003577 "data_gen": {
3578 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3579 },
3580 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003581 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003582 # Elementwise Binary Operators
3583 "add": {
3584 "op": Op.ADD,
3585 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_binary_broadcast,
3588 TosaTensorGen.tgBroadcastFuzz,
3589 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003590 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evRankMismatch,
3595 TosaErrorValidator.evWrongInputType,
3596 TosaErrorValidator.evWrongOutputType,
3597 TosaErrorValidator.evWrongInputList,
3598 TosaErrorValidator.evWrongOutputList,
3599 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003600 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003601 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003602 "data_gen": {
3603 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3604 },
3605 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003606 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 "arithmetic_right_shift": {
3608 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3609 "operands": (2, 0),
3610 "build_fcn": (
3611 build_arithmetic_right_shift,
3612 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 TosaArgGen.agArithmeticRightShift,
3615 ),
3616 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003617 "error_if_validators": (
3618 TosaErrorValidator.evRankMismatch,
3619 TosaErrorValidator.evWrongInputType,
3620 TosaErrorValidator.evWrongOutputType,
3621 TosaErrorValidator.evWrongInputList,
3622 TosaErrorValidator.evWrongOutputList,
3623 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003624 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003625 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 "bitwise_and": {
3628 "op": Op.BITWISE_AND,
3629 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003630 "build_fcn": (
3631 build_binary_broadcast,
3632 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003633 TosaTensorValuesGen.tvgLazyGenDefault,
3634 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 "error_if_validators": (
3638 TosaErrorValidator.evRankMismatch,
3639 TosaErrorValidator.evWrongInputType,
3640 TosaErrorValidator.evWrongOutputType,
3641 TosaErrorValidator.evWrongInputList,
3642 TosaErrorValidator.evWrongOutputList,
3643 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003644 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003645 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003647 "bitwise_or": {
3648 "op": Op.BITWISE_OR,
3649 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 "build_fcn": (
3651 build_binary_broadcast,
3652 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003653 TosaTensorValuesGen.tvgLazyGenDefault,
3654 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evRankMismatch,
3659 TosaErrorValidator.evWrongInputType,
3660 TosaErrorValidator.evWrongOutputType,
3661 TosaErrorValidator.evWrongInputList,
3662 TosaErrorValidator.evWrongOutputList,
3663 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003664 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003665 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003667 "bitwise_xor": {
3668 "op": Op.BITWISE_XOR,
3669 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003670 "build_fcn": (
3671 build_binary_broadcast,
3672 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003673 TosaTensorValuesGen.tvgLazyGenDefault,
3674 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003675 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003677 "error_if_validators": (
3678 TosaErrorValidator.evRankMismatch,
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003684 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003685 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003687 "intdiv": {
3688 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003689 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_binary_broadcast,
3692 TosaTensorGen.tgBroadcastFuzz,
3693 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003694 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003695 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003696 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003697 "error_if_validators": (
3698 TosaErrorValidator.evRankMismatch,
3699 TosaErrorValidator.evWrongInputType,
3700 TosaErrorValidator.evWrongOutputType,
3701 TosaErrorValidator.evWrongInputList,
3702 TosaErrorValidator.evWrongOutputList,
3703 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003704 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003705 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003707 "logical_and": {
3708 "op": Op.LOGICAL_AND,
3709 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003710 "build_fcn": (
3711 build_binary_broadcast,
3712 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003713 TosaTensorValuesGen.tvgLazyGenDefault,
3714 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003716 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evRankMismatch,
3719 TosaErrorValidator.evWrongInputType,
3720 TosaErrorValidator.evWrongOutputType,
3721 TosaErrorValidator.evWrongInputList,
3722 TosaErrorValidator.evWrongOutputList,
3723 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003724 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003725 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003726 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003727 "logical_left_shift": {
3728 "op": Op.LOGICAL_LEFT_SHIFT,
3729 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 "build_fcn": (
3731 build_binary_broadcast,
3732 TosaTensorGen.tgBroadcastFuzz,
3733 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003734 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003736 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003737 "error_if_validators": (
3738 TosaErrorValidator.evRankMismatch,
3739 TosaErrorValidator.evWrongInputType,
3740 TosaErrorValidator.evWrongOutputType,
3741 TosaErrorValidator.evWrongInputList,
3742 TosaErrorValidator.evWrongOutputList,
3743 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003744 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003745 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 "logical_right_shift": {
3748 "op": Op.LOGICAL_RIGHT_SHIFT,
3749 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003750 "build_fcn": (
3751 build_binary_broadcast,
3752 TosaTensorGen.tgBroadcastFuzz,
3753 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003754 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003755 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003756 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003757 "error_if_validators": (
3758 TosaErrorValidator.evRankMismatch,
3759 TosaErrorValidator.evWrongInputType,
3760 TosaErrorValidator.evWrongOutputType,
3761 TosaErrorValidator.evWrongInputList,
3762 TosaErrorValidator.evWrongOutputList,
3763 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003764 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003765 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003767 "logical_or": {
3768 "op": Op.LOGICAL_OR,
3769 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003770 "build_fcn": (
3771 build_binary_broadcast,
3772 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003773 TosaTensorValuesGen.tvgLazyGenDefault,
3774 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003776 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003777 "error_if_validators": (
3778 TosaErrorValidator.evRankMismatch,
3779 TosaErrorValidator.evWrongInputType,
3780 TosaErrorValidator.evWrongOutputType,
3781 TosaErrorValidator.evWrongInputList,
3782 TosaErrorValidator.evWrongOutputList,
3783 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003784 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003785 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003787 "logical_xor": {
3788 "op": Op.LOGICAL_XOR,
3789 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003790 "build_fcn": (
3791 build_binary_broadcast,
3792 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003793 TosaTensorValuesGen.tvgLazyGenDefault,
3794 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003797 "error_if_validators": (
3798 TosaErrorValidator.evRankMismatch,
3799 TosaErrorValidator.evWrongInputType,
3800 TosaErrorValidator.evWrongOutputType,
3801 TosaErrorValidator.evWrongInputList,
3802 TosaErrorValidator.evWrongOutputList,
3803 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003804 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 "maximum": {
3808 "op": Op.MAXIMUM,
3809 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003810 "build_fcn": (
3811 build_binary_broadcast,
3812 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003813 TosaTensorValuesGen.tvgLazyGenDefault,
3814 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003815 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003816 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003817 "error_if_validators": (
3818 TosaErrorValidator.evRankMismatch,
3819 TosaErrorValidator.evWrongInputType,
3820 TosaErrorValidator.evWrongOutputType,
3821 TosaErrorValidator.evWrongInputList,
3822 TosaErrorValidator.evWrongOutputList,
3823 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003824 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003826 "data_gen": {
3827 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3828 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "minimum": {
3831 "op": Op.MINIMUM,
3832 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003833 "build_fcn": (
3834 build_binary_broadcast,
3835 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003836 TosaTensorValuesGen.tvgLazyGenDefault,
3837 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003840 "error_if_validators": (
3841 TosaErrorValidator.evRankMismatch,
3842 TosaErrorValidator.evWrongInputType,
3843 TosaErrorValidator.evWrongOutputType,
3844 TosaErrorValidator.evWrongInputList,
3845 TosaErrorValidator.evWrongOutputList,
3846 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003847 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003849 "data_gen": {
3850 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3851 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "mul": {
3854 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003855 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003856 "build_fcn": (
3857 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003858 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 TosaTensorValuesGen.tvgMul,
3860 TosaArgGen.agMul,
3861 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003863 "error_if_validators": (
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
3866 TosaErrorValidator.evWrongInputList,
3867 TosaErrorValidator.evWrongOutputList,
3868 TosaErrorValidator.evRankMismatch,
3869 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003870 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003872 "data_gen": {
3873 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3874 },
3875 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 "pow": {
3878 "op": Op.POW,
3879 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 "build_fcn": (
3881 build_binary_broadcast,
3882 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003883 TosaTensorValuesGen.tvgPow,
3884 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003885 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003886 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003887 "error_if_validators": (
3888 TosaErrorValidator.evRankMismatch,
3889 TosaErrorValidator.evWrongInputType,
3890 TosaErrorValidator.evWrongOutputType,
3891 TosaErrorValidator.evWrongInputList,
3892 TosaErrorValidator.evWrongOutputList,
3893 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003894 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003895 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003896 "data_gen": {
3897 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 "sub": {
3901 "op": Op.SUB,
3902 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003903 "build_fcn": (
3904 build_binary_broadcast,
3905 TosaTensorGen.tgBroadcastFuzz,
3906 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003907 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003908 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 "error_if_validators": (
3911 TosaErrorValidator.evRankMismatch,
3912 TosaErrorValidator.evWrongInputType,
3913 TosaErrorValidator.evWrongOutputType,
3914 TosaErrorValidator.evWrongInputList,
3915 TosaErrorValidator.evWrongOutputList,
3916 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003917 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003918 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003919 "data_gen": {
3920 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3921 },
3922 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003924 "table": {
3925 "op": Op.TABLE,
3926 # Use the automatic generation functions to create the input array
3927 # but create the table tensor in the build function, as it may be
3928 # a different type from the input
3929 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003930 "build_fcn": (
3931 build_table,
3932 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003933 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 TosaArgGen.agTable,
3935 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003936 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003937 "error_if_validators": (
3938 TosaErrorValidator.evWrongInputType,
3939 TosaErrorValidator.evWrongOutputType,
3940 TosaErrorValidator.evWrongInputList,
3941 TosaErrorValidator.evWrongOutputList,
3942 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 # Elementwise Unary operators
3945 "abs": {
3946 "op": Op.ABS,
3947 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003948 "build_fcn": (
3949 build_unary,
3950 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003951 TosaTensorValuesGen.tvgLazyGenDefault,
3952 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003953 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003954 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003955 "error_if_validators": (
3956 TosaErrorValidator.evWrongInputType,
3957 TosaErrorValidator.evWrongOutputType,
3958 TosaErrorValidator.evWrongInputList,
3959 TosaErrorValidator.evWrongOutputList,
3960 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003961 "data_gen": {
3962 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3963 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003964 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003965 "bitwise_not": {
3966 "op": Op.BITWISE_NOT,
3967 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003968 "build_fcn": (
3969 build_unary,
3970 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003971 TosaTensorValuesGen.tvgLazyGenDefault,
3972 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003974 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003975 "error_if_validators": (
3976 TosaErrorValidator.evWrongInputType,
3977 TosaErrorValidator.evWrongOutputType,
3978 TosaErrorValidator.evWrongInputList,
3979 TosaErrorValidator.evWrongOutputList,
3980 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003981 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003982 "ceil": {
3983 "op": Op.CEIL,
3984 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 "build_fcn": (
3986 build_unary,
3987 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003988 TosaTensorValuesGen.tvgLazyGenDefault,
3989 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003990 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003992 "error_if_validators": (
3993 TosaErrorValidator.evWrongInputType,
3994 TosaErrorValidator.evWrongOutputType,
3995 TosaErrorValidator.evWrongInputList,
3996 TosaErrorValidator.evWrongOutputList,
3997 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003998 "data_gen": {
3999 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4000 },
4001 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004002 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 "clz": {
4004 "op": Op.CLZ,
4005 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004006 "build_fcn": (
4007 build_unary,
4008 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004009 TosaTensorValuesGen.tvgLazyGenDefault,
4010 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004011 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004012 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004013 "error_if_validators": (
4014 TosaErrorValidator.evWrongInputType,
4015 TosaErrorValidator.evWrongOutputType,
4016 TosaErrorValidator.evWrongInputList,
4017 TosaErrorValidator.evWrongOutputList,
4018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004019 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004020 "cos": {
4021 "op": Op.COS,
4022 "operands": (1, 0),
4023 "build_fcn": (
4024 build_unary,
4025 TosaTensorGen.tgBasic,
4026 TosaTensorValuesGen.tvgLazyGenDefault,
4027 TosaArgGen.agNone,
4028 ),
4029 "types": TYPE_FP,
4030 "error_if_validators": (
4031 TosaErrorValidator.evWrongInputType,
4032 TosaErrorValidator.evWrongOutputType,
4033 TosaErrorValidator.evWrongInputList,
4034 TosaErrorValidator.evWrongOutputList,
4035 ),
4036 "data_gen": {
4037 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4038 },
4039 "compliance": {"abs_error_normal_divisor": 2},
4040 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "exp": {
4042 "op": Op.EXP,
4043 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004044 "build_fcn": (
4045 build_unary,
4046 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004047 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004048 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004049 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004050 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004051 "error_if_validators": (
4052 TosaErrorValidator.evWrongInputType,
4053 TosaErrorValidator.evWrongOutputType,
4054 TosaErrorValidator.evWrongInputList,
4055 TosaErrorValidator.evWrongOutputList,
4056 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004057 "data_gen": {
4058 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 "floor": {
4062 "op": Op.FLOOR,
4063 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 "build_fcn": (
4065 build_unary,
4066 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004067 TosaTensorValuesGen.tvgLazyGenDefault,
4068 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004070 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004071 "error_if_validators": (
4072 TosaErrorValidator.evWrongInputType,
4073 TosaErrorValidator.evWrongOutputType,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004077 "data_gen": {
4078 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4079 },
4080 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "log": {
4083 "op": Op.LOG,
4084 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004085 "build_fcn": (
4086 build_unary,
4087 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004088 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004089 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004091 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004092 "error_if_validators": (
4093 TosaErrorValidator.evWrongInputType,
4094 TosaErrorValidator.evWrongOutputType,
4095 TosaErrorValidator.evWrongInputList,
4096 TosaErrorValidator.evWrongOutputList,
4097 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004098 "data_gen": {
4099 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4100 },
4101 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004102 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "logical_not": {
4104 "op": Op.LOGICAL_NOT,
4105 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004106 "build_fcn": (
4107 build_unary,
4108 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004109 TosaTensorValuesGen.tvgLazyGenDefault,
4110 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004112 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004113 "error_if_validators": (
4114 TosaErrorValidator.evWrongInputType,
4115 TosaErrorValidator.evWrongOutputType,
4116 TosaErrorValidator.evWrongInputList,
4117 TosaErrorValidator.evWrongOutputList,
4118 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 "negate": {
4121 "op": Op.NEGATE,
4122 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004123 "build_fcn": (
4124 build_unary,
4125 TosaTensorGen.tgBasic,
4126 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004127 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004129 "qgen": TosaQuantGen.qgUnary,
4130 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004131 "error_if_validators": (
4132 TosaErrorValidator.evInputZeroPointNotZero,
4133 TosaErrorValidator.evOutputZeroPointNotZero,
4134 TosaErrorValidator.evWrongInputType,
4135 TosaErrorValidator.evWrongOutputType,
4136 TosaErrorValidator.evWrongInputList,
4137 TosaErrorValidator.evWrongOutputList,
4138 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004139 "data_gen": {
4140 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4141 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 "reciprocal": {
4144 "op": Op.RECIPROCAL,
4145 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004146 "build_fcn": (
4147 build_unary,
4148 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004149 TosaTensorValuesGen.tvgLazyGenDefault,
4150 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004151 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004152 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004153 "error_if_validators": (
4154 TosaErrorValidator.evWrongInputType,
4155 TosaErrorValidator.evWrongOutputType,
4156 TosaErrorValidator.evWrongInputList,
4157 TosaErrorValidator.evWrongOutputList,
4158 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004159 "data_gen": {
4160 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4161 },
4162 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004163 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 "rsqrt": {
4165 "op": Op.RSQRT,
4166 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004167 "build_fcn": (
4168 build_unary,
4169 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004170 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004171 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004172 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004173 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004174 "error_if_validators": (
4175 TosaErrorValidator.evWrongInputType,
4176 TosaErrorValidator.evWrongOutputType,
4177 TosaErrorValidator.evWrongInputList,
4178 TosaErrorValidator.evWrongOutputList,
4179 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004180 "data_gen": {
4181 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4182 },
4183 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004184 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004185 "sin": {
4186 "op": Op.SIN,
4187 "operands": (1, 0),
4188 "build_fcn": (
4189 build_unary,
4190 TosaTensorGen.tgBasic,
4191 TosaTensorValuesGen.tvgLazyGenDefault,
4192 TosaArgGen.agNone,
4193 ),
4194 "types": TYPE_FP,
4195 "error_if_validators": (
4196 TosaErrorValidator.evWrongInputType,
4197 TosaErrorValidator.evWrongOutputType,
4198 TosaErrorValidator.evWrongInputList,
4199 TosaErrorValidator.evWrongOutputList,
4200 ),
4201 "data_gen": {
4202 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4203 },
4204 "compliance": {"abs_error_normal_divisor": 2},
4205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 # Elementwise Ternary operators
4207 "select": {
4208 "op": Op.SELECT,
4209 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004210 "build_fcn": (
4211 build_select,
4212 TosaTensorGen.tgBroadcastFuzz,
4213 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004214 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004216 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004217 "error_if_validators": (
4218 TosaErrorValidator.evRankMismatch,
4219 TosaErrorValidator.evWrongInputType,
4220 TosaErrorValidator.evWrongOutputType,
4221 TosaErrorValidator.evWrongInputList,
4222 TosaErrorValidator.evWrongOutputList,
4223 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004224 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004225 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004226 "data_gen": {
4227 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4228 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004230 # Comparison operators
4231 "equal": {
4232 "op": Op.EQUAL,
4233 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004234 "build_fcn": (
4235 build_comparison,
4236 TosaTensorGen.tgBroadcastFuzz,
4237 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004238 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004239 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004240 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004241 "error_if_validators": (
4242 TosaErrorValidator.evRankMismatch,
4243 TosaErrorValidator.evWrongInputType,
4244 TosaErrorValidator.evWrongOutputType,
4245 TosaErrorValidator.evWrongInputList,
4246 TosaErrorValidator.evWrongOutputList,
4247 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004248 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004250 "data_gen": {
4251 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 "greater_equal": {
4255 "op": Op.GREATER_EQUAL,
4256 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004257 "build_fcn": (
4258 build_comparison,
4259 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004260 TosaTensorValuesGen.tvgLazyGenDefault,
4261 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004263 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004264 "error_if_validators": (
4265 TosaErrorValidator.evRankMismatch,
4266 TosaErrorValidator.evWrongInputType,
4267 TosaErrorValidator.evWrongOutputType,
4268 TosaErrorValidator.evWrongInputList,
4269 TosaErrorValidator.evWrongOutputList,
4270 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004271 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004272 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004273 "data_gen": {
4274 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004277 "greater": {
4278 "op": Op.GREATER,
4279 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004280 "build_fcn": (
4281 build_comparison,
4282 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004283 TosaTensorValuesGen.tvgLazyGenDefault,
4284 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004285 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004286 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004287 "error_if_validators": (
4288 TosaErrorValidator.evRankMismatch,
4289 TosaErrorValidator.evWrongInputType,
4290 TosaErrorValidator.evWrongOutputType,
4291 TosaErrorValidator.evWrongInputList,
4292 TosaErrorValidator.evWrongOutputList,
4293 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004294 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004295 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004296 "data_gen": {
4297 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4298 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 # Reduction operators
4301 "reduce_all": {
4302 "op": Op.REDUCE_ALL,
4303 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004304 "build_fcn": (
4305 build_reduce,
4306 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004307 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004308 TosaArgGen.agAxis,
4309 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004310 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004311 "error_if_validators": (
4312 TosaErrorValidator.evAxisLargerRank,
4313 TosaErrorValidator.evAxisSmallerZero,
4314 TosaErrorValidator.evShapeOfAxisNotOne,
4315 TosaErrorValidator.evWrongInputType,
4316 TosaErrorValidator.evWrongOutputType,
4317 TosaErrorValidator.evWrongRank,
4318 TosaErrorValidator.evWrongInputList,
4319 TosaErrorValidator.evWrongOutputList,
4320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 "reduce_any": {
4323 "op": Op.REDUCE_ANY,
4324 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004325 "build_fcn": (
4326 build_reduce,
4327 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004328 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004329 TosaArgGen.agAxis,
4330 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004331 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004332 "error_if_validators": (
4333 TosaErrorValidator.evAxisLargerRank,
4334 TosaErrorValidator.evAxisSmallerZero,
4335 TosaErrorValidator.evShapeOfAxisNotOne,
4336 TosaErrorValidator.evWrongInputType,
4337 TosaErrorValidator.evWrongOutputType,
4338 TosaErrorValidator.evWrongRank,
4339 TosaErrorValidator.evWrongInputList,
4340 TosaErrorValidator.evWrongOutputList,
4341 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 "reduce_max": {
4344 "op": Op.REDUCE_MAX,
4345 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004346 "build_fcn": (
4347 build_reduce,
4348 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004349 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004350 TosaArgGen.agAxis,
4351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004352 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004353 "error_if_validators": (
4354 TosaErrorValidator.evAxisLargerRank,
4355 TosaErrorValidator.evAxisSmallerZero,
4356 TosaErrorValidator.evShapeOfAxisNotOne,
4357 TosaErrorValidator.evWrongInputType,
4358 TosaErrorValidator.evWrongOutputType,
4359 TosaErrorValidator.evWrongRank,
4360 TosaErrorValidator.evWrongInputList,
4361 TosaErrorValidator.evWrongOutputList,
4362 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004363 "data_gen": {
4364 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4365 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004368 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004370 "build_fcn": (
4371 build_reduce,
4372 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004373 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004374 TosaArgGen.agAxis,
4375 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004376 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004377 "error_if_validators": (
4378 TosaErrorValidator.evAxisLargerRank,
4379 TosaErrorValidator.evAxisSmallerZero,
4380 TosaErrorValidator.evShapeOfAxisNotOne,
4381 TosaErrorValidator.evWrongInputType,
4382 TosaErrorValidator.evWrongOutputType,
4383 TosaErrorValidator.evWrongRank,
4384 TosaErrorValidator.evWrongInputList,
4385 TosaErrorValidator.evWrongOutputList,
4386 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004387 "data_gen": {
4388 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4389 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 "reduce_product": {
4392 "op": Op.REDUCE_PRODUCT,
4393 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004394 "build_fcn": (
4395 build_reduce,
4396 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004397 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004398 TosaArgGen.agAxis,
4399 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004400 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004401 "error_if_validators": (
4402 TosaErrorValidator.evAxisLargerRank,
4403 TosaErrorValidator.evAxisSmallerZero,
4404 TosaErrorValidator.evShapeOfAxisNotOne,
4405 TosaErrorValidator.evWrongInputType,
4406 TosaErrorValidator.evWrongOutputType,
4407 TosaErrorValidator.evWrongRank,
4408 TosaErrorValidator.evWrongInputList,
4409 TosaErrorValidator.evWrongOutputList,
4410 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004411 "data_gen": {
4412 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4413 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004415 "reduce_sum": {
4416 "op": Op.REDUCE_SUM,
4417 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004418 "build_fcn": (
4419 build_reduce,
4420 TosaTensorGen.tgBasic,
4421 TosaTensorValuesGen.tvgReduceSum,
4422 TosaArgGen.agAxis,
4423 ),
James Ward24dbc422022-10-19 12:20:31 +01004424 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004425 "error_if_validators": (
4426 TosaErrorValidator.evAxisLargerRank,
4427 TosaErrorValidator.evAxisSmallerZero,
4428 TosaErrorValidator.evShapeOfAxisNotOne,
4429 TosaErrorValidator.evWrongInputType,
4430 TosaErrorValidator.evWrongOutputType,
4431 TosaErrorValidator.evWrongRank,
4432 TosaErrorValidator.evWrongInputList,
4433 TosaErrorValidator.evWrongOutputList,
4434 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004435 "data_gen": {
4436 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4437 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004438 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004439 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 "concat": {
4441 "op": Op.CONCAT,
4442 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004443 "build_fcn": (
4444 build_concat,
4445 TosaTensorGen.tgConcat,
4446 TosaTensorValuesGen.tvgConcat,
4447 TosaArgGen.agAxis,
4448 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004449 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004450 "error_if_validators": (
4451 TosaErrorValidator.evAxisLargerRank,
4452 TosaErrorValidator.evAxisSmallerZero,
4453 TosaErrorValidator.evConcatInputRankMismatch,
4454 TosaErrorValidator.evConcatShapeSumMismatch,
4455 TosaErrorValidator.evConcatInputDimMismatch,
4456 TosaErrorValidator.evWrongInputType,
4457 TosaErrorValidator.evWrongOutputType,
4458 TosaErrorValidator.evWrongOutputList,
4459 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004460 "data_gen": {
4461 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4462 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004463 },
4464 "pad": {
4465 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004467 "build_fcn": (
4468 build_pad,
4469 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004470 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004471 TosaArgGen.agPad,
4472 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004473 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004474 "error_if_validators": (
4475 TosaErrorValidator.evWrongInputType,
4476 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004477 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004478 TosaErrorValidator.evWrongOutputType,
4479 TosaErrorValidator.evWrongInputList,
4480 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004481 TosaErrorValidator.evRankMismatch,
4482 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004483 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004484 "data_gen": {
4485 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4486 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 },
Won Jeona21b2e82023-08-10 10:33:01 +00004488 "dim": {
4489 "op": Op.DIM,
4490 "operands": (1, 0),
4491 "build_fcn": (
4492 build_dim,
4493 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004494 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004495 TosaArgGen.agAxis,
4496 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004497 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004498 "error_if_validators": (
4499 TosaErrorValidator.evAxisLargerRank,
4500 TosaErrorValidator.evAxisSmallerZero,
4501 TosaErrorValidator.evWrongInputType,
4502 TosaErrorValidator.evWrongInputList,
4503 TosaErrorValidator.evWrongOutputList,
4504 TosaErrorValidator.evWrongRank,
4505 ),
4506 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004507 "reshape": {
4508 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004509 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004510 "build_fcn": (
4511 build_reshape,
4512 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004513 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004514 TosaArgGen.agReshape,
4515 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004516 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004517 "error_if_validators": (
4518 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4519 TosaErrorValidator.evWrongInputType,
4520 TosaErrorValidator.evWrongOutputType,
4521 TosaErrorValidator.evWrongInputList,
4522 TosaErrorValidator.evWrongOutputList,
4523 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004524 "data_gen": {
4525 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4526 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004527 },
4528 "reverse": {
4529 "op": Op.REVERSE,
4530 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004531 "build_fcn": (
4532 build_reverse,
4533 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004534 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004535 TosaArgGen.agAxis,
4536 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004537 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 "error_if_validators": (
4539 TosaErrorValidator.evAxisSmallerZero,
4540 TosaErrorValidator.evAxisLargerRank,
4541 TosaErrorValidator.evWrongInputType,
4542 TosaErrorValidator.evWrongOutputType,
4543 TosaErrorValidator.evWrongInputList,
4544 TosaErrorValidator.evWrongOutputList,
4545 ),
evacha0198477222024-01-26 12:25:32 +00004546 "data_gen": {
4547 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4548 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004549 },
4550 "slice": {
4551 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004552 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004553 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004554 "build_fcn": (
4555 build_slice,
4556 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004557 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004558 TosaArgGen.agSlice,
4559 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004560 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004561 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004562 # TODO Turn off these error categories for now as the reference
4563 # model cannot allocate memory space for empty tensor. We probably
                                                                                    4564 # could report an accurate error message at the right place during
                                                                                    4565 # execution.
4566 # TosaErrorValidator.evStartSmallerZero,
4567 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004568 TosaErrorValidator.evStartSizeOutsideBounds,
4569 TosaErrorValidator.evSizeOutputShapeMismatch,
4570 TosaErrorValidator.evInputSizeStartLengthMismatch,
4571 TosaErrorValidator.evWrongRank,
4572 TosaErrorValidator.evWrongInputType,
4573 TosaErrorValidator.evWrongOutputType,
4574 TosaErrorValidator.evWrongInputList,
4575 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004576 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004577 ),
evacha017f7d4252024-01-24 12:08:09 +00004578 "data_gen": {
4579 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4580 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004581 },
4582 "tile": {
4583 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004584 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004585 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004586 "build_fcn": (
4587 build_tile,
4588 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004589 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004590 TosaArgGen.agTile,
4591 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004592 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004593 "error_if_validators": (
4594 TosaErrorValidator.evWrongInputType,
4595 TosaErrorValidator.evWrongOutputType,
4596 TosaErrorValidator.evWrongInputList,
4597 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004598 TosaErrorValidator.evRankMismatch,
4599 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004600 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004601 "data_gen": {
4602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4603 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 },
4605 "transpose": {
4606 "op": Op.TRANSPOSE,
4607 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004608 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004609 "build_fcn": (
4610 build_transpose,
4611 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004612 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004613 TosaArgGen.agTranspose,
4614 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004615 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004616 "error_if_validators": (
4617 TosaErrorValidator.evIndexOutsideBounds,
4618 TosaErrorValidator.evIndexUsedTwice,
4619 TosaErrorValidator.evWrongInputType,
4620 TosaErrorValidator.evWrongOutputType,
4621 TosaErrorValidator.evWrongInputList,
4622 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004623 TosaErrorValidator.evWrongRank,
4624 TosaErrorValidator.evRankMismatch,
4625 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004626 ),
evacha0198477222024-01-26 12:25:32 +00004627 "data_gen": {
4628 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4629 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004630 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004631 # Data nodes
4632 "const": {
4633 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004634 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004635 "build_fcn": (
4636 build_const,
4637 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004638 TosaTensorValuesGen.tvgLazyGenDefault,
4639 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004640 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004641 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004642 "data_gen": {
4643 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004646 "identity": {
4647 "op": Op.IDENTITY,
4648 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004649 "build_fcn": (
4650 build_unary,
4651 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004652 TosaTensorValuesGen.tvgLazyGenDefault,
4653 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004655 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004656 "data_gen": {
4657 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004659 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004660 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004661 "gather": {
4662 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004663 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004664 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004665 "build_fcn": (
4666 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004667 TosaTensorGen.tgGather,
4668 TosaTensorValuesGen.tvgGather,
4669 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004670 ),
James Ward24dbc422022-10-19 12:20:31 +01004671 "types": (
4672 DType.INT8,
4673 DType.INT16,
4674 DType.INT32,
4675 DType.FP16,
4676 DType.BF16,
4677 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004678 DType.FP8E4M3,
4679 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004680 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004681 "error_if_validators": (
4682 TosaErrorValidator.evWrongInputType,
4683 TosaErrorValidator.evWrongOutputType,
4684 TosaErrorValidator.evWrongInputList,
4685 TosaErrorValidator.evWrongOutputList,
4686 TosaErrorValidator.evWrongRank,
4687 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004688 "data_gen": {
4689 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4690 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004691 },
4692 "scatter": {
4693 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004694 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004696 "build_fcn": (
4697 build_scatter,
4698 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004699 TosaTensorValuesGen.tvgScatter,
4700 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004701 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004702 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004703 "error_if_validators": (
4704 TosaErrorValidator.evWrongInputType,
4705 TosaErrorValidator.evWrongOutputType,
4706 TosaErrorValidator.evWrongInputList,
4707 TosaErrorValidator.evWrongOutputList,
4708 TosaErrorValidator.evWrongRank,
4709 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004710 "data_gen": {
4711 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4712 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004713 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004714 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004715 "resize": {
4716 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004717 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004718 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004719 "build_fcn": (
4720 build_resize,
4721 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004722 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004723 TosaArgGen.agResize,
4724 ),
James Ward24dbc422022-10-19 12:20:31 +01004725 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004726 "invalid_test_validators": (
4727 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004728 ),
4729 "error_if_validators": (
4730 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004731 TosaErrorValidator.evScaleSmallerEqualZero,
4732 TosaErrorValidator.evScaleNLargerMax,
4733 TosaErrorValidator.evScaleDLargerMax,
4734 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004735 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004736 TosaErrorValidator.evBorderSmallerMin,
4737 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004738 TosaErrorValidator.evWrongInputType,
4739 TosaErrorValidator.evWrongOutputType,
4740 TosaErrorValidator.evWrongRank,
4741 TosaErrorValidator.evWrongInputList,
4742 TosaErrorValidator.evWrongOutputList,
4743 TosaErrorValidator.evBatchMismatch,
4744 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004745 TosaErrorValidator.evResizeOutputShapeMismatch,
4746 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004747 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004748 "data_gen": {
4749 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4750 },
4751 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004752 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004753 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004754 "cast": {
4755 "op": Op.CAST,
4756 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004757 "build_fcn": (
4758 build_cast,
4759 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004760 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004761 TosaArgGen.agCast,
4762 ),
James Ward8b390432022-08-12 20:48:56 +01004763 "types": (
4764 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004765 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004766 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004767 DType.INT8,
4768 DType.INT16,
4769 DType.INT32,
4770 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004771 DType.FP8E4M3,
4772 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004773 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004774 "error_if_validators": (
4775 TosaErrorValidator.evWrongInputType,
4776 TosaErrorValidator.evWrongOutputType,
4777 TosaErrorValidator.evWrongInputList,
4778 TosaErrorValidator.evWrongOutputList,
4779 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004780 "data_gen": {
4781 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4782 },
4783 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004784 },
4785 "rescale": {
4786 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004787 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004788 "build_fcn": (
4789 build_rescale,
4790 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004791 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004792 TosaArgGen.agRescale,
4793 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004794 "types": [
4795 DType.UINT8,
4796 DType.INT8,
4797 DType.INT16,
4798 DType.INT32,
4799 DType.INT48,
4800 DType.UINT16,
4801 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004802 "error_if_validators": (
4803 TosaErrorValidator.evInputZeroPointNotZero,
4804 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004805 TosaErrorValidator.evU16InputZeroPointNotValid,
4806 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004807 TosaErrorValidator.evScaleTrue,
4808 TosaErrorValidator.evScaleNotTrue,
4809 TosaErrorValidator.evWrongInputType,
4810 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004811 TosaErrorValidator.evWrongInputList,
4812 TosaErrorValidator.evWrongOutputList,
4813 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004814 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004815 # Custom
4816 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004817 # Control flow operators
# Two variants of cond_if: one that generates one of two constant tensors (no
# inputs to the basic blocks, one output) and another that either adds or
# subtracts two tensors (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004821 "cond_if_const": {
4822 "op": Op.COND_IF,
4823 "operands": (0, 2),
4824 "build_fcn": (
4825 build_cond_if_const,
4826 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004827 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004828 TosaArgGen.agCondIf,
4829 ),
4830 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004831 "error_if_validators": (
4832 TosaErrorValidator.evOutputListThenGraphMismatch,
4833 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004834 TosaErrorValidator.evCondIfCondNotMatchingBool,
4835 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004837 },
4838 "cond_if_binary": {
4839 "op": Op.COND_IF,
4840 "operands": (2, 0),
4841 "build_fcn": (
4842 build_cond_if_binary,
4843 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004844 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004845 TosaArgGen.agCondIf,
4846 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004847 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004848 "error_if_validators": (
4849 TosaErrorValidator.evInputListThenGraphMismatch,
4850 TosaErrorValidator.evInputListElseGraphMismatch,
4851 TosaErrorValidator.evOutputListThenGraphMismatch,
4852 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004853 TosaErrorValidator.evCondIfCondNotMatchingBool,
4854 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004855 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004857 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004858 "while_loop": {
4859 "op": Op.WHILE_LOOP,
4860 "operands": (0, 1),
4861 "build_fcn": (
4862 build_while_loop,
4863 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004864 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004865 TosaArgGen.agWhileLoop,
4866 ),
4867 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004868 "error_if_validators": (
4869 TosaErrorValidator.evInputListOutputListMismatch,
4870 TosaErrorValidator.evInputListCondGraphMismatch,
4871 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4872 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4873 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004874 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004875 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004876 },
Luke Hutton57287132023-02-06 14:54:18 +00004877 "fft2d": {
4878 "op": Op.FFT2D,
4879 "operands": (2, 0),
4880 "rank": (3, 3),
4881 "build_fcn": (
4882 build_fft2d,
4883 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004884 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004885 TosaArgGen.agFFT2d,
4886 ),
4887 "types": [DType.FP32],
4888 "error_if_validators": (
4889 TosaErrorValidator.evWrongInputType,
4890 TosaErrorValidator.evWrongOutputType,
4891 TosaErrorValidator.evWrongInputList,
4892 TosaErrorValidator.evWrongOutputList,
4893 TosaErrorValidator.evWrongRank,
4894 TosaErrorValidator.evBatchMismatch,
4895 TosaErrorValidator.evKernelNotPowerOfTwo,
4896 TosaErrorValidator.evFFTInputShapeMismatch,
4897 TosaErrorValidator.evFFTOutputShapeMismatch,
4898 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004899 "data_gen": {
4900 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4901 },
Luke Hutton57287132023-02-06 14:54:18 +00004902 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004903 "rfft2d": {
4904 "op": Op.RFFT2D,
4905 "operands": (1, 0),
4906 "rank": (3, 3),
4907 "build_fcn": (
4908 build_rfft2d,
4909 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004910 TosaTensorValuesGen.tvgLazyGenDefault,
4911 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004912 ),
4913 "types": [DType.FP32],
4914 "error_if_validators": (
4915 TosaErrorValidator.evWrongInputType,
4916 TosaErrorValidator.evWrongOutputType,
4917 TosaErrorValidator.evWrongInputList,
4918 TosaErrorValidator.evWrongOutputList,
4919 TosaErrorValidator.evWrongRank,
4920 TosaErrorValidator.evBatchMismatch,
4921 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004922 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004923 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004924 "data_gen": {
4925 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4926 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004927 },
Won Jeon74342e52024-01-09 00:34:40 +00004928 # Shape
4929 "add_shape": {
4930 "op": Op.ADD_SHAPE,
4931 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004932 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004933 "build_fcn": (
4934 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004935 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004936 TosaTensorValuesGen.tvgAddSub,
4937 TosaArgGen.agNone,
4938 ),
4939 "types": [DType.SHAPE],
4940 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4941 },
4942 "sub_shape": {
4943 "op": Op.SUB_SHAPE,
4944 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004945 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004946 "build_fcn": (
4947 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004948 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004949 TosaTensorValuesGen.tvgAddSub,
4950 TosaArgGen.agNone,
4951 ),
4952 "types": [DType.SHAPE],
4953 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4954 },
4955 "mul_shape": {
4956 "op": Op.MUL_SHAPE,
4957 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004958 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004959 "build_fcn": (
4960 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004961 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004962 TosaTensorValuesGen.tvgMul,
4963 TosaArgGen.agNone,
4964 ),
4965 "types": [DType.SHAPE],
4966 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4967 },
4968 "div_shape": {
4969 "op": Op.DIV_SHAPE,
4970 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004971 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004972 "build_fcn": (
4973 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004974 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004975 TosaTensorValuesGen.tvgIntDiv,
4976 TosaArgGen.agNone,
4977 ),
4978 "types": [DType.SHAPE],
4979 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4980 },
4981 "concat_shape": {
4982 "op": Op.CONCAT_SHAPE,
4983 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004984 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004985 "build_fcn": (
4986 build_concat,
4987 TosaTensorGen.tgConcat,
4988 TosaTensorValuesGen.tvgConcat,
4989 TosaArgGen.agNone,
4990 ),
4991 "types": [DType.SHAPE],
4992 "error_if_validators": (),
4993 },
4994 "const_shape": {
4995 "op": Op.CONST_SHAPE,
4996 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004997 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004998 "build_fcn": (
4999 build_const,
5000 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005001 TosaTensorValuesGen.tvgLazyGenDefault,
5002 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005003 ),
5004 "types": [DType.SHAPE],
5005 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005006 }
5007
Kevin Cheng550ccc52021-03-03 11:21:43 -08005008
Eric Kunzee5e26762020-10-13 16:11:07 -07005009class OutputShaper:
5010 # Methods in this class compute the expected output shape and datatype
5011 # for common classes of operations
def __init__(self):
    """Create a shaper; no per-instance state is kept."""
5014
5015 # These methods return arguments that can be used for
5016 # creating a new output tensor
5017 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of an element-wise binary op with broadcast.

    The output shape follows ``a``. When no ERROR_IF case is requested, any
    dimension of size 1 in ``a`` takes its size from ``b`` (broadcasting).

    Args:
        ser: serializer used to register the output (its ``addOutput``).
        rng: random generator providing ``integers()`` and ``choice()``
            (presumably ``numpy.random.Generator`` -- confirm with caller).
        a, b: input tensors. They must share dtype, and also rank unless
            ``ErrorIf.RankMismatch`` is requested.
        error_name: optional ``ErrorIf`` case to generate.
            ``DimensionMismatch`` grows one randomly chosen output dimension
            by 1. ``WrongOutputType`` picks an output dtype other than
            ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so a rank-0 input has no
    # dimension to fuzz. For rank >= 1 the index is still drawn
    # unconditionally, so the random stream (and hence every generated
    # test) is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Kept as a set difference so the candidate order handed to
        # rng.choice (and therefore the chosen dtype) is unchanged.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005050
5051 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output of a binary op whose inputs must match exactly.

    No broadcasting is allowed: ``a`` and ``b`` must have the same dtype and
    identical shapes. The output takes that shape (as a fresh list) and
    dtype.

    Returns:
        The value returned by ``ser.addOutput(shape, a.dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005062
5063 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of an element-wise unary operator.

    The output has the same shape as ``a``. Its dtype is ``a.dtype``, except
    for the ``ErrorIf.WrongOutputType`` case, where ``rng.choice`` picks a
    different dtype from a fixed candidate list.

    Returns:
        The value returned by ``ser.addOutput(a.shape, dtype)``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    # ERROR_IF case: any candidate other than the input dtype is "wrong".
    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005081
5082 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor of a SELECT operator.

    The output shape follows ``cond``. Where a ``cond`` dimension is 1 and no
    ERROR_IF case is requested, the largest of the ``cond``/``a``/``b`` sizes
    for that dimension is used (broadcasting).

    Args:
        ser: serializer used to register the output (its ``addOutput``).
        rng: random generator providing ``integers()`` and ``choice()``
            (presumably ``numpy.random.Generator`` -- confirm with caller).
        cond: condition tensor.
        a, b: value tensors. They must share dtype, and all three inputs
            must share rank unless ``ErrorIf.RankMismatch`` is requested.
        error_name: optional ``ErrorIf`` case to generate.
            ``DimensionMismatch`` grows one randomly chosen output dimension
            by 1. ``WrongOutputType`` picks a dtype other than ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # rng.integers(0, 0) raises ValueError, so skip fuzzing for rank-0
    # inputs. For rank >= 1 the index is still drawn unconditionally, so the
    # random stream (and hence every generated test) is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Kept as a set difference so the candidate order handed to
        # rng.choice (and therefore the chosen dtype) is unchanged.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005115
5116 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor of an element-wise comparison op.

    The output shape follows ``a``. Any dimension of size 1 in ``a`` takes
    its size from ``b`` when ``b`` has that dimension (broadcasting).

    Args:
        ser: serializer used to register the output (its ``addOutput``).
        rng: random generator providing ``integers()`` and ``choice()``
            (presumably ``numpy.random.Generator`` -- confirm with caller).
        a, b: input tensors. They must share dtype, and also rank unless
            ``ErrorIf.RankMismatch`` is requested.
        error_name: optional ``ErrorIf`` case to generate.
            ``DimensionMismatch`` grows one randomly chosen output dimension
            by 1. ``WrongOutputType`` picks a non-BOOL dtype.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # rng.integers(0, 0) raises ValueError, so skip fuzzing for rank-0
    # inputs. For rank >= 1 the index is still drawn unconditionally, so the
    # random stream (and hence every generated test) is unchanged.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so any pick is a wrong output type.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005149
5150 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a reduction along ``axis``.

    The output copies ``a.shape`` with dimension ``axis`` set to 1 and keeps
    ``a.dtype``. ERROR_IF handling:
    * ``AxisSmallerZero`` / ``AxisLargerRank``: the shape is left untouched.
    * ``ShapeOfAxisNotOne``: if ``a.shape[axis]`` is 1 it is replaced by
      ``rng.integers(2, 10)``, so the reduced axis is not 1.
    * ``WrongOutputType``: ``rng.choice`` picks a dtype other than
      ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_errors:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005178
5179 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the INT32 output tensor of an ARGMAX along ``axis``.

    The output is a copy of ``a.shape`` with dimension ``axis`` removed.
    ERROR_IF handling:
    * ``AxisSmallerZero`` / ``AxisLargerRank``: no dimension is removed.
    * ``ArgmaxOutputRankMismatch``: ``rng.choice([True, False])`` decides to
      drop the leading dimension (only when more than one remains);
      otherwise a trailing 1 is appended.
    * ``ArgmaxOutputShapeMismatch``: every dimension grows by
      ``rng.integers(1, 10)``.
    * ``WrongOutputType``: ``rng.choice`` picks a dtype other than INT32.

    Returns:
        The value returned by ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    wrong_dtypes = list(set(candidate_dtypes) - {DType.INT32})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005214
5215 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of a CONV2D operator.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC. Each
    spatial output size is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    Args:
        ser: serializer used to register the output (its ``addOutput``).
        rng: random generator providing ``choice()``.
        ifm: input feature map tensor.
        filter: weight tensor (OHWI).
        accum_dtype: accumulator dtype, used as the output dtype in the
            normal case.
        strides, dilations: per-spatial-axis values (index 0 = H, 1 = W).
        padding: padding values summed per spatial axis (0/1 for H, 2/3
            for W).
        error_name: optional ``ErrorIf`` case. ``ConvOutputShapeMismatch``
            grows H and/or W by a whole number of strides.
            ``WrongInputType`` uses INT32 as the output dtype.
            ``WrongOutputType`` picks an excluded ("wrong") output dtype.

    Returns:
        The value returned by ``ser.addOutput(ofm_shape, out_dtype)``.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # These must be one if/elif/else chain. With a separate `if` for FP8,
        # the FP16 exclusions were overwritten by the final `else`.
        if ifm.dtype == DType.FP16:
            # Presumably both FP16 and FP32 are valid outputs for FP16 input,
            # so neither may be chosen as the "wrong" type.
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005268
5269 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of a CONV3D operator.

    Layouts: IFM is NDHWC, the filter is ODHWI and the OFM is NDHWC. For
    spatial axis ``k`` (0 = D, 1 = H, 2 = W) the output size is
    ``(in - 1 + padding[2k] + padding[2k+1] - (kernel - 1) * dilations[k])
    // strides[k] + 1``.

    ERROR_IF handling:
    * ``ConvOutputShapeMismatch``: ``rng.choice([1, 2, 3, 4])`` selects D, H,
      W, or all three (4). Each selected size grows by a random multiple of
      its stride.
    * ``WrongInputType``: the output dtype is INT32.
    * ``WrongOutputType``: ``rng.choice`` picks a dtype outside the exclusion
      list (FP16/FP32 for FP16 input, otherwise the accumulator dtype).

    Returns:
        The value returned by ``ser.addOutput(ofm_shape, out_dtype)``.
    """
    out_spatial = []
    for axis in range(3):
        size = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1
        out_spatial.append(size)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        for axis in range(3):
            if which in (axis + 1, 4):
                # Whole strides keep clear of the non-integer ERROR_IF case.
                out_spatial[axis] = out_spatial[axis] + (
                    rng.choice(steps) * strides[axis]
                )

    ofm_shape = [ifm.shape[0], *out_spatial, filter.shape[0]]

    # With a wrong input type, any potentially-correct output type will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5330
5331 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M). For
    spatial axis ``k`` (0 = H, 1 = W) the output size is
    ``(in - 1 + padding[2k] + padding[2k+1] - (filter.shape[k] - 1)
    * dilations[k]) // strides[k] + 1``.

    ERROR_IF handling:
    * ``ConvOutputShapeMismatch``: ``rng.choice([1, 2, 3])`` selects H, W, or
      both (3). Each selected size grows by a random multiple of its stride.
    * ``WrongInputType``: the output dtype is INT32.
    * ``WrongOutputType``: ``rng.choice`` picks a dtype outside the exclusion
      list (FP16/FP32 for FP16 input, otherwise the accumulator dtype).

    Returns:
        The value returned by ``ser.addOutput(ofm_shape, out_dtype)``.
    """
    out_spatial = []
    for axis in range(2):
        size = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1
        out_spatial.append(size)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        for axis in range(2):
            if which in (axis + 1, 3):
                # Whole strides keep clear of the non-integer ERROR_IF case.
                out_spatial[axis] = out_spatial[axis] + (
                    rng.choice(steps) * strides[axis]
                )

    ofm_shape = [ifm.shape[0], *out_spatial, filter.shape[2] * filter.shape[3]]

    # With a wrong input type, any potentially-correct output type will do.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005381
5382 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005383 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005384 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005385 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005386 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005387 h = 1
5388 w = 1
5389 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005390 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
5391 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005392
5393 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005394 choices = [1, 2, 3]
5395 change = rng.choice(choices)
5396 # increment in multiples of stride to not hit non-integer error case
5397 if change in [1, 3]:
5398 h = h + (rng.choice(choices) * stride[0])
5399 if change in [2, 3]:
5400 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07005401 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005402
5403 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005404 all_dtypes = [
5405 DType.INT8,
5406 DType.INT16,
5407 DType.INT32,
5408 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005409 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01005410 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01005411 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005412 DType.FP8E4M3,
5413 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005414 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01005415 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
5416 outputDType = rng.choice(wrong_dtypes)
5417 else:
5418 outputDType = ifm.dtype
5419
5420 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005421
5422 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005423 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005424 # input: N, IC
5425 # filter: OC, IC
5426 # output: N, OC
5427
5428 output_shape = [input.shape[0], filter.shape[0]]
5429
James Ward8b390432022-08-12 20:48:56 +01005430 # Validated in arg_gen (also invalidated for ErrorIf)
5431 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005432
Kevin Cheng550ccc52021-03-03 11:21:43 -08005433 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005434
5435 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005436 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07005437 # a: N, H, C
5438 # b: N, C, W
5439 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07005440
Kevin Cheng2d60f002021-06-09 14:18:32 -07005441 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005442
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005443 if error_name == ErrorIf.WrongOutputType:
5444 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005445 incorrect_types = (
5446 DType.INT4,
5447 DType.INT8,
5448 DType.INT16,
5449 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005450 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005451 DType.FP16,
5452 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005453 DType.FP8E4M3,
5454 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005455 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005456 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005457 incorrect_types = (
5458 DType.INT4,
5459 DType.INT8,
5460 DType.INT16,
5461 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005462 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005463 DType.FP16,
5464 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005465 DType.FP8E4M3,
5466 DType.FP8E5M2,
5467 )
5468 elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
5469 incorrect_types = (
5470 DType.INT4,
5471 DType.INT8,
5472 DType.INT16,
5473 DType.INT32,
5474 DType.INT48,
5475 DType.FP32,
5476 DType.BF16,
5477 DType.FP8E4M3,
5478 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005479 )
James Ward24dbc422022-10-19 12:20:31 +01005480 elif (
5481 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
5482 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005483 incorrect_types = (
5484 DType.INT4,
5485 DType.INT8,
5486 DType.INT16,
5487 DType.INT32,
5488 DType.INT48,
Won Jeon2c34b462024-02-06 18:37:00 +00005489 DType.FP8E4M3,
5490 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005491 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005492 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005493 elif error_name == ErrorIf.WrongInputType:
5494 # Pick some potentially correct output dtype if input type is incorrect
5495 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005496 else:
James Ward8b390432022-08-12 20:48:56 +01005497 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07005498
Kevin Cheng550ccc52021-03-03 11:21:43 -08005499 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005500
5501 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00005502 def concatOp(ser, rng, axis, inputs, error_name=None):
5503 input1 = inputs[0]
5504 remaining_inputs = inputs[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07005505
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005506 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01005507 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005508 if not (
5509 # unable to concat tensors of different ranks
5510 error_name == ErrorIf.ConcatInputRankMismatch
5511 # unable to concat tensors along an invalid axis
5512 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005513 ):
5514 for tensor in remaining_inputs:
5515 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07005516
Matthew Haddon01c359d2021-10-15 16:30:48 +01005517 if error_name == ErrorIf.ConcatShapeSumMismatch:
5518 output_shape[axis] += rng.integers(5, 10)
5519
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005520 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005521 all_dtypes = {
5522 DType.INT8,
5523 DType.INT16,
5524 DType.INT32,
5525 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005526 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005527 DType.FP16,
5528 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005529 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005530 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
5531 outputDType = rng.choice(wrong_dtypes)
5532 else:
5533 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01005534
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005535 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005536
5537 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005538 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005539
5540 output_shape = a.shape.copy()
5541
5542 for i in range(len(output_shape)):
5543 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
5544
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01005545 if error_name == ErrorIf.PadOutputShapeMismatch:
5546 bad_dim = rng.choice(range(len(output_shape)))
Tai Lye095da72024-01-25 22:00:18 +00005547 output_shape[bad_dim] += rng.choice([1, 2])
Luke Huttona4e48ca2023-02-22 11:53:48 +00005548 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005549 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01005550
Matthew Haddone807aae2021-10-11 18:12:58 +01005551 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005552 all_dtypes = [
5553 DType.INT8,
5554 DType.INT16,
5555 DType.INT32,
5556 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005557 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01005558 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01005559 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005560 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005561 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5562 outputDType = rng.choice(wrong_dtypes)
5563 else:
5564 outputDType = a.dtype
5565
5566 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005567
5568 @staticmethod
Won Jeona21b2e82023-08-10 10:33:01 +00005569 def dimOp(ser, rng, a, axis, error_name=None):
Tai Ly8690a082023-12-18 20:40:24 +00005570 output_shape = [1]
Won Jeona21b2e82023-08-10 10:33:01 +00005571
5572 if error_name == ErrorIf.WrongOutputType:
5573 all_dtypes = [
5574 DType.INT8,
5575 DType.INT16,
5576 DType.INT32,
5577 DType.INT48,
5578 DType.FP32,
5579 DType.FP16,
5580 DType.BF16,
5581 ]
5582 wrong_dtypes = list(set(all_dtypes))
5583 outputDType = rng.choice(wrong_dtypes)
5584 else:
5585 outputDType = DType.SHAPE
5586
5587 return ser.addOutput(output_shape, outputDType)
5588
5589 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005590 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005591 output_shape = shape.copy()
5592
Matthew Haddone807aae2021-10-11 18:12:58 +01005593 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5594 for i in range(len(output_shape)):
5595 output_shape[i] = output_shape[i] + rng.integers(1, 10)
5596
5597 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005598 all_dtypes = [
5599 DType.INT8,
5600 DType.INT16,
5601 DType.INT32,
5602 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005603 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005604 DType.FP16,
5605 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005606 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005607 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5608 outputDType = rng.choice(wrong_dtypes)
5609 else:
5610 outputDType = a.dtype
5611
5612 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005613
5614 @staticmethod
Luke Huttona4e48ca2023-02-22 11:53:48 +00005615 def sliceOp(ser, rng, input, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005616
Matthew Haddone807aae2021-10-11 18:12:58 +01005617 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005618 all_dtypes = [
5619 DType.INT8,
5620 DType.INT16,
5621 DType.INT32,
5622 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005623 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005624 DType.FP16,
5625 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005626 ]
Luke Huttona4e48ca2023-02-22 11:53:48 +00005627 wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
Matthew Haddone807aae2021-10-11 18:12:58 +01005628 outputDType = rng.choice(wrong_dtypes)
5629 else:
Luke Huttona4e48ca2023-02-22 11:53:48 +00005630 outputDType = input.dtype
Matthew Haddone807aae2021-10-11 18:12:58 +01005631
Luke Huttona4e48ca2023-02-22 11:53:48 +00005632 output_shape = size.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005633 if error_name == ErrorIf.SizeOutputShapeMismatch:
Matthew Haddone807aae2021-10-11 18:12:58 +01005634 for index in range(len(output_shape)):
5635 if output_shape[index] <= 2:
5636 output_shape[index] = output_shape[index] + rng.choice([1, 2])
5637 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005638 output_shape[index] = output_shape[index] + rng.choice(
5639 [-2, -1, 1, 2]
5640 )
Luke Huttona4e48ca2023-02-22 11:53:48 +00005641 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
5642 output_shape = input.shape.copy()
5643 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005644 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Matthew Haddone807aae2021-10-11 18:12:58 +01005645
5646 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005647
5648 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005649 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005650
5651 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005652 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005653
5654 for i in range(len(output_shape)):
5655 output_shape[i] = a.shape[i] * multiples[i]
5656
Luke Huttona4e48ca2023-02-22 11:53:48 +00005657 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005658 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005659
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005660 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005661 all_dtypes = [
5662 DType.INT8,
5663 DType.INT16,
5664 DType.INT32,
5665 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005666 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005667 DType.FP16,
5668 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005669 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005670 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5671 outputDType = rng.choice(wrong_dtypes)
5672 else:
5673 outputDType = a.dtype
5674
5675 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005676
5677 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005678 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005679 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005680
Kevin Cheng550ccc52021-03-03 11:21:43 -08005681 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005682
Luke Huttona4e48ca2023-02-22 11:53:48 +00005683 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01005684 for i in range(len(output_shape)):
5685 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005686
Luke Huttona4e48ca2023-02-22 11:53:48 +00005687 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5688 for i in range(len(output_shape)):
5689 output_shape[i] += rng.integers(1, 10)
5690 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005691 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005692
Matthew Haddone807aae2021-10-11 18:12:58 +01005693 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005694 all_dtypes = [
5695 DType.INT8,
5696 DType.INT16,
5697 DType.INT32,
5698 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005699 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005700 DType.FP16,
5701 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005702 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005703 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5704 outputDType = rng.choice(wrong_dtypes)
5705 else:
5706 outputDType = a.dtype
5707
5708 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005709
5710 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005711 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005712 if error_name != ErrorIf.WrongRank:
5713 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005714 assert len(indices.shape) == 2
5715 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005716
Kevin Cheng77d0f762020-11-24 10:26:32 -08005717 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5718
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005719 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005720 all_dtypes = [
5721 DType.INT8,
5722 DType.INT16,
5723 DType.INT32,
5724 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005725 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005726 DType.FP16,
5727 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005728 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005729 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
5730 outputDType = rng.choice(wrong_dtypes)
5731 else:
5732 outputDType = values.dtype
5733
5734 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005735
5736 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005737 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005738 if error_name != ErrorIf.WrongRank:
5739 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005740 assert len(indices.shape) == 2
5741 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005742 assert values_in.shape[0] == indices.shape[0] # N
5743 assert input.shape[1] == indices.shape[1] # W
5744 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005745
5746 output_shape = values_in.shape
5747
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005748 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005749 all_dtypes = [
5750 DType.INT8,
5751 DType.INT16,
5752 DType.INT32,
5753 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005754 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005755 DType.FP16,
5756 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005757 DType.FP8E4M3,
5758 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005759 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005760 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
5761 outputDType = rng.choice(wrong_dtypes)
5762 else:
5763 outputDType = values_in.dtype
5764
5765 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005766
5767 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005768 def tableOp(ser, rng, input, error_name=None):
5769 # Same shape as the input, dtype dependent on input dtype
5770 if error_name != ErrorIf.WrongInputType:
5771 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00005772 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005773 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005774 wrong_dtypes = [
5775 DType.INT8,
5776 DType.INT16,
5777 DType.INT32,
5778 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005779 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005780 DType.FP16,
5781 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005782 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005783 wrong_dtypes.remove(output_dtype)
5784 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005785 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005786
5787 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005788 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005789 serializer,
5790 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005791 input,
5792 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005793 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005794 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005795 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005796 input_dtype,
5797 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005798 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005799 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005800 # Calculate OH, OW
5801 scale_y_n = scale[0]
5802 scale_y_d = scale[1]
5803 scale_x_n = scale[2]
5804 scale_x_d = scale[3]
5805 if error_name == ErrorIf.ScaleSmallerEqualZero:
5806 scale_y_n = max(scale_y_n, 1)
5807 scale_y_d = max(scale_y_d, 1)
5808 scale_x_n = max(scale_x_n, 1)
5809 scale_x_d = max(scale_x_d, 1)
5810
5811 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
5812 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
5813
5814 if error_name is not None:
5815 # Make sure the output tensor is valid, which can occur when
5816 # scale, offset or border have been changed for ERROR_IFs
5817 oh = max(oh, 1)
5818 ow = max(ow, 1)
5819 if error_name != ErrorIf.MaxDimExceeded:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005820 oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
5821 ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005822
5823 if error_name == ErrorIf.ResizeOutputShapeMismatch:
5824 choices = [1, 2, 3]
5825 change = rng.choice(choices)
5826 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
5827 if change in [1, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005828 if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005829 oh -= scale_y_d
5830 assert oh > 0 # Should have been caught in agResize
5831 else:
5832 oh += scale_y_d
5833 if change in [2, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005834 if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005835 ow -= scale_x_d
5836 assert ow > 0 # Should have been caught in agResize
5837 else:
5838 ow += scale_x_d
5839
Matthew Haddon848efb42021-09-09 12:30:53 +01005840 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005841 output_dims = [
5842 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005843 oh,
5844 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005845 input.shape[0],
5846 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005847 elif error_name == ErrorIf.BatchMismatch:
5848 output_dims = [
5849 input.shape[0] + rng.integers(1, 10),
5850 oh,
5851 ow,
5852 input.shape[3],
5853 ]
5854 elif error_name == ErrorIf.ChannelMismatch:
5855 output_dims = [
5856 input.shape[0],
5857 oh,
5858 ow,
5859 input.shape[3] + rng.integers(1, 10),
5860 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01005861 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005862 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005863
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005864 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005865
5866 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005867 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005868 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005869
5870 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005871 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005872 if error_name == ErrorIf.ConvOutputShapeMismatch:
5873 choices = [1, 2, 3]
5874 change = rng.choice(choices)
5875 if change in [1, 3]:
5876 output_shape[1] = output_shape[1] + rng.choice(choices)
5877 if change in [2, 3]:
5878 output_shape[2] = output_shape[2] + rng.choice(choices)
5879
James Ward8b390432022-08-12 20:48:56 +01005880 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005881 # Pick some potentially correct output dtype if input type is incorrect
5882 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005883 else:
James Ward8b390432022-08-12 20:48:56 +01005884 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005885
5886 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005887 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005888 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005889 else:
5890 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005891 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005892 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005893
Kevin Cheng550ccc52021-03-03 11:21:43 -08005894 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005895
5896 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00005897 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
5898 outputs = []
5899
5900 assert ifm1.dtype == ifm2.dtype
5901 input_dtype = ifm1.dtype
5902
5903 if error_name != ErrorIf.FFTInputShapeMismatch:
5904 assert ifm1.shape == ifm2.shape
5905
5906 input_shape = ifm1.shape
5907 if error_name != ErrorIf.WrongRank:
5908 assert len(input_shape) == 3
5909
5910 output_shape = input_shape.copy()
5911 output_dtype = input_dtype
5912
5913 if error_name == ErrorIf.WrongOutputType:
5914 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005915 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton57287132023-02-06 14:54:18 +00005916 output_dtype = rng.choice(wrong_dtypes)
5917 elif error_name == ErrorIf.BatchMismatch:
5918 output_shape[0] += rng.integers(1, 10)
5919 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5920 modify_dim = rng.choice([1, 2])
5921 output_shape[modify_dim] += rng.integers(1, 10)
5922
5923 outputs.append(serializer.addOutput(output_shape, output_dtype))
5924 outputs.append(serializer.addOutput(output_shape, output_dtype))
5925 return outputs
5926
5927 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00005928 def rfft2dOp(serializer, rng, value, error_name=None):
5929 outputs = []
5930
5931 input_shape = value.shape
5932 if error_name != ErrorIf.WrongRank:
5933 assert len(input_shape) == 3
5934
5935 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
5936
5937 output_dtype = value.dtype
5938 if error_name == ErrorIf.WrongOutputType:
5939 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005940 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton261b7b62023-01-10 14:50:31 +00005941 output_dtype = rng.choice(wrong_dtypes)
5942 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00005943 output_shape[0] += rng.integers(1, 10)
5944 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5945 modify_dim = rng.choice([1, 2])
5946 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00005947
5948 outputs.append(serializer.addOutput(output_shape, output_dtype))
5949 outputs.append(serializer.addOutput(output_shape, output_dtype))
5950 return outputs
Won Jeon74342e52024-01-09 00:34:40 +00005951
5952 @staticmethod
5953 def addShapeOp(ser, rng, a, b, error_name=None):
5954 if error_name != ErrorIf.RankMismatch:
5955 assert len(a.shape) == len(b.shape)
5956 assert a.dtype == b.dtype
5957
5958 shape = []
5959 for i in range(len(a.shape)):
5960 shape.append(a.shape[i])
5961
5962 fuzz_idx = rng.integers(0, len(a.shape))
5963 if error_name == ErrorIf.DimensionMismatch:
5964 shape[fuzz_idx] += 1
5965
5966 if error_name == ErrorIf.WrongOutputType:
5967 wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
5968 outputDType = rng.choice(wrong_dtypes)
5969 else:
5970 outputDType = DType.SHAPE
5971 return ser.addOutput(shape, outputDType)