blob: b472087198fd1b2e3465fa0f0bf0e10f0d7b3df8 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Three-line C++ comment header written at the top of every generated
# meta-data .cpp file; the copyright year is taken when the module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
def __init__(self, args):
    """Initialise the generator from the parsed program arguments ``args``.

    Seeds ``self.rng`` with ``args.random_seed``, calls
    createDynamicOpLists() and initOpListDefaults(), creates the quantization
    generator, the desc.json schema validator and the data generator library,
    and finally works out ``self.random_float_range`` - the (low, high) random
    value range for each floating point dtype.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # Force makeShape to do a specific starting shape
    self.targetted_shape = None
    # JSON schema validation
    self.descSchemaValidator = TestDescSchemaValidator()
    # Data generator library is sometimes needed for compliance set up
    # even if we are generating the data later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clip_to_dtype(value, max_fp):
        # "max"/"-max" select the dtype extremes; any other non-zero value is
        # trimmed so it lies within [-max_fp, max_fp]
        if value == "max":
            return max_fp
        if value == "-max":
            return -max_fp
        if value < 0:
            return max(value, -max_fp)
        if value > 0:
            return min(value, max_fp)
        return value

    # Work out floating point range
    self.random_float_range = {}
    fp_dtypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    for fp_dtype in fp_dtypes:
        dtype_max = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype]
        clipped = [clip_to_dtype(v, dtype_max) for v in args.tensor_fp_value_range]
        self.random_float_range[fp_dtype] = tuple(sorted(clipped))
84
def createSerializer(self, opName, testPath):
    """Create the test directory basePath/opName/testPath and a fresh serializer.

    Sets ``self.testPath`` to ``opName/testPath`` and ``self.ser`` to a new
    TosaSerializer writing into that directory. The const mode depends on the
    program arguments: INPUTS for lazy data generation, EMBED_DUMP when
    ``dump_consts`` is set, otherwise EMBED.
    """
    self.testPath = os.path.join(opName, testPath)

    test_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(test_dir, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        const_mode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        const_mode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        const_mode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(test_dir, const_mode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
def getSerializer(self):
    """Return the current serializer (None until createSerializer has run)."""
    current = self.ser
    return current
101
def serialize(self, testName, metaData=None):
    """Write the current test out to its directory ``basePath/testPath``.

    Files written:
      * ``<testName>.tosa`` - the serialized flatbuffer.
      * ``<testName>_meta_data_gen.cpp`` - the "data_gen" meta data as a C++
        raw string, only when it is present and lazy data generation is on.
      * ``<testName>_meta_compliance.cpp`` - the "compliance" meta data as a
        C++ raw string, only when it is present.
      * ``desc.json`` - the test descriptor, with ``metaData`` (if any) under
        "meta". It is validated against the schema before the .cpp files and
        desc.json are written.
    """
    test_dir = Path(self.basePath) / self.testPath

    # Write out TOSA flatbuffer binary
    with (test_dir / f"{testName}.tosa").open("wb") as fb_file:
        fb_file.write(self.ser.serialize())

    # Get JSON descriptor from serializer and attach any extra meta data
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    def write_cpp_meta(file_name, comment, var_name, payload):
        # Emit ``payload`` as JSON inside a C++ raw string literal variable
        with (test_dir / file_name).open("w") as cpp_file:
            cpp_file.write(TOSA_AUTOGENERATED_HEADER)
            cpp_file.write(comment)
            cpp_file.write(f'const char* {var_name}_{test_dir.stem} = R"(')
            json.dump(payload, cpp_file)
            cpp_file.write(')";\n\n')

    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation meta data is only needed for lazy generation
            write_cpp_meta(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            # Compliance validation meta data
            write_cpp_meta(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    with (test_dir / "desc.json").open("w") as desc_file:
        json.dump(desc, desc_file, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a newly seeded generator.

    When ``seed`` is None the generator is seeded with ``random_seed + 1``.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
150
def getDTypeRange(self, dtype, high_inclusive=False):
    """Return the (low, high) random value range for ``dtype``.

    Floating point dtypes return their entry in ``self.random_float_range``
    unchanged, whatever ``high_inclusive`` is. For all other dtypes the high
    bound is exclusive, unless ``high_inclusive`` is True, in which case
    ``high - 1`` is returned instead. SHAPE uses the first two values of
    ``args.tensor_shape_range``.

    Raises:
        Exception: if ``dtype`` is not a known dtype.
    """
    if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        return self.random_float_range[dtype]

    if dtype == DType.SHAPE:
        bounds = tuple(self.args.tensor_shape_range[0:2])
    else:
        fixed_ranges = (
            (DType.BOOL, (0, 2)),
            (DType.UINT8, (0, 256)),
            (DType.UINT16, (0, 65536)),
            # TOSA specific INT4 weight range from -7 to 7
            (DType.INT4, (-7, 8)),
            (DType.INT8, (-128, 128)),
            (DType.INT16, (-32768, 32768)),
            (DType.INT32, (-(1 << 31), (1 << 31))),
            (DType.INT48, (-(1 << 47), (1 << 47))),
        )
        for known_dtype, known_bounds in fixed_ranges:
            if dtype == known_dtype:
                bounds = known_bounds
                break
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

    # Exclusive high: low <= range < high; inclusive: low <= range <= high
    return (bounds[0], bounds[1] - 1) if high_inclusive else bounds
185
def getRandTensor(self, shape, dtype, data_range=None):
    """Return a random numpy array of ``shape`` holding ``dtype`` data.

    Values are drawn from ``data_range`` (low, high) when given, otherwise
    from getDTypeRange(dtype). BOOL ignores the range and picks True/False.
    BF16 and FP8 values are produced as float32 and passed through the gtu
    conversion helpers; INT48/SHAPE use int64 storage and any integer type
    without its own branch (e.g. INT4, INT32) uses int32 storage.
    """
    low, high = self.getDTypeRange(dtype) if data_range is None else data_range

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    integer_casts = (
        (DType.INT8, np.int8),
        (DType.UINT8, np.uint8),
        (DType.INT16, np.int16),
        (DType.UINT16, np.uint16),
        (DType.INT48, np.int64),
        (DType.SHAPE, np.int64),
    )
    for int_dtype, cast in integer_casts:
        if dtype == int_dtype:
            return cast(self.rng.integers(low=low, high=high, size=shape))

    float_dtypes = (DType.FP16, DType.BF16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2)
    if dtype not in float_dtypes:
        # All other integer types
        return np.int32(self.rng.integers(low=low, high=high, size=shape))

    f_tensor = self.rng.uniform(low=low, high=high, size=shape)
    if dtype == DType.FP16:
        return np.float16(f_tensor)

    f32_tensor = np.float32(f_tensor)
    if dtype == DType.BF16:
        # Floor the last 16 bits of each f32 value
        return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
    if dtype == DType.FP8E4M3:
        return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
    if dtype == DType.FP8E5M2:
        return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
    return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700229
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Each placeholder gets random data from getRandTensor, or None when
    ``args.lazy_data_gen`` is set (data is generated later).

    Returns:
        list of the values returned by ``self.ser.addPlaceholder``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, data))
    return placeholders
242
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant tensor per (shape, dtype) pair to the serializer.

    Each constant gets random data from getRandTensor, or None when
    ``args.lazy_data_gen`` is set (data is generated later).

    Returns:
        list of the values returned by ``self.ser.addConst``, in order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, data))
    return consts
255
def makeShape(self, rank):
    """Return a tensor shape as an int32 numpy array.

    If a target shape has been set with setTargetShape, it is returned
    (as int32) regardless of ``rank``. Otherwise ``rank`` dimensions are drawn
    from [args.tensor_shape_range[0], args.tensor_shape_range[1]).
    """
    # Compare against the None sentinel set in __init__ rather than testing
    # truthiness: an explicit rank-0 target shape ([]) must still be honoured,
    # and a multi-element numpy array has no truth value (raises ValueError).
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700266
def setTargetShape(self, shape):
    """Force subsequent makeShape calls to return ``shape``."""
    self.targetted_shape = shape
269
def randInt(self, low=0, high=256):
    """Return one random np.int32 value from [low, high) using ``self.rng``."""
    samples = np.int32(self.rng.integers(low=low, high=high, size=1))
    return samples[0]
272
def getRandNumberDType(self, dtype):
    """Return a single random value of ``dtype`` drawn from getDTypeRange(dtype).

    FP32/FP16 give numpy float scalars; BF16/FP8 give whatever the gtu
    float32 conversion helpers return for a float32 scalar; BOOL gives a
    random True/False; INT48 and SHAPE give np.int64 and every other dtype
    gives np.int32.
    """
    low, high = self.getDTypeRange(dtype)

    if dtype == DType.FP32:
        return np.float32(self.rng.uniform(low=low, high=high))
    if dtype == DType.FP16:
        return np.float16(self.rng.uniform(low=low, high=high))
    if dtype in (DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
        rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
        if dtype == DType.BF16:
            return gtu.vect_f32_to_bf16(rand_f32)
        if dtype == DType.FP8E4M3:
            return gtu.vect_f32_to_fp8e4m3(rand_f32)
        return gtu.vect_f32_to_fp8e5m2(rand_f32)
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])

    # Special size: INT48/SHAPE need 64-bit storage, the rest fit in int32
    to_numpy = np.int64 if dtype in (DType.INT48, DType.SHAPE) else np.int32
    return to_numpy(self.rng.integers(low, high, size=1))[0]
296
def shapeStr(self, shape):
    """Return ``shape`` as an "x"-separated string, e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
def typeStr(self, dtype):
    """Return the short string used for ``dtype`` in test names.

    A list/tuple of at least two dtypes gives the names of the first two
    joined by "x". The third entry, when present, is the accumulator and is
    left out of the name, but it is still converted, so it must be a known
    dtype.

    Raises:
        Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in gtu.DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return gtu.DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = list(map(self.typeStr, dtype))
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
def typeWidth(self, dtype):
    """Get the datatype width for data types

    Looks the width up in gtu.DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700326
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size`` in place and return ``shape``.

    Nothing is changed when no maximum batch size is given or when explicit
    target shapes were requested.
    """
    limit = self.args.max_batch_size
    if not limit or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], limit)
    return shape
332
def makeDimension(self):
    """Return one random dimension from [tensor_shape_range[0], tensor_shape_range[1])."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
337
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Create the compliance meta data for an op's expected output tensor.

    Args:
        op: op definition dict; reads "op" and the optional "compliance" entry.
        inputType: dtype of the op's input.
        argsDict: test arguments; "dg_type" selects the data generation type
            and some modes read further keys ("s", "ks"/"ksb",
            "max_abs_value", "n").
        outputTensor: the result tensor; its dtype sets "data_type".
        errorName: ERROR_IF test name, or a falsy value for a normal test.

    Returns:
        dict with "mode", "data_type" and any mode specific "*_info" entry,
        or None for error tests and dtypes not supported by compliance.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    unsupported_input_ops = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    # No compliance for error tests or unsupported types currently
    if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in unsupported_input_ops
    ):
        return None

    # Create compliance meta data for expected output tensor
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    op_compliance = op["compliance"] if "compliance" in op else {}
    dg_type = argsDict["dg_type"]

    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif dg_type == gtu.DataGenType.OP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in op_compliance:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op_compliance["ulp"]}
    elif "relative" in op_compliance:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op_compliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op_compliance["abs_error_lower_bound"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
400
# Build Op functions
# Create the output tensor (calling OutputShaper as needed)
# Do final tweaks to attributes (if necessary for errorIf)
# Add Op into graph
# Return resulting tensor information or BuildInfo
406
class BuildInfo:
    """Enhanced build information containing result tensor and associated compliance dict."""

    def __init__(self, resultTensor, complianceDict):
        # Accept either a single tensor (with one optional compliance dict) or
        # a list of tensors (with an optional parallel list of dicts).
        if isinstance(resultTensor, list):
            assert complianceDict is None or isinstance(complianceDict, list)
            self.resultTensorList = resultTensor
            self.complianceDictList = complianceDict
        else:
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )

    def getComplianceInfo(self):
        """Return {"version": "0.1", "tensors": {name: dict}} or None.

        Only tensors with a non-None compliance dict are included; None is
        returned when there is no compliance data at all.
        """
        if self.complianceDictList is None:
            return None
        tens_dict = {
            tens.name: comp
            for tens, comp in zip(self.resultTensorList, self.complianceDictList)
            if comp is not None
        }
        if not tens_dict:
            return None
        return {"version": "0.1", "tensors": tens_dict}
Eric Kunzee5e26762020-10-13 16:11:07 -0700440
def build_unary(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input operator.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: list holding the single input tensor.
        args_dict: test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF test name, or None.
        qinfo: zero points; qinfo[0]/qinfo[1] feed the NEGATE attribute.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects this test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, a.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    passes_error_ifs = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passes_error_ifs:
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700493
def build_binary_broadcast(
    self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a two-input operator using OutputShaper.binaryBroadcastOp.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: the two input tensors.
        args_dict: test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF test name, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects this test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tensor.name]
    )

    passes_error_ifs = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passes_error_ifs:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a two-input operator using OutputShaper.binaryNonBroadcastOp.

    ``validator_fcns`` and ``error_name`` are accepted but not used: no
    ERROR_IF validation or compliance data is produced here.

    Returns:
        the result tensor itself (not a BuildInfo).
    """
    result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result_tens.name])
    return result_tens
540
def build_arithmetic_right_shift(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARITHMETIC_RIGHT_SHIFT style two-input operator.

    Args:
        op: op definition dict (reads "op" and "operands").
        inputs: the two input tensors.
        args_dict: test arguments; "round" is passed to
            ArithmeticRightShiftAttribute and the dict is forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF test name, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects this test.
    """
    assert len(inputs) == 2
    a, b = inputs
    # Named ``rounding`` (not ``round``) so the builtin round() is not shadowed
    rounding = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(rounding)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800586
def build_mul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MUL operator.

    MUL is a binary operator but it also takes a shift value tensor, so
    ``inputs`` holds three tensors: the two operands and the shift. Integer
    inputs produce an INT32 result; for the WrongOutputType ERROR_IF test
    the output dtype is instead picked at random from INT8/INT16/INT48.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects this test.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    a, b, shift = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tensor.setDtype(wrong_dtype)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name, shift.name], [result_tensor.name]
    )

    passes_error_ifs = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passes_error_ifs:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700639
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TABLE operator from the single input and ``args_dict["table"]``.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects this test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    table_values = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table_values)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tensor.name]
    )

    passes_error_ifs = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passes_error_ifs:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700682
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test from three inputs ``[cond, a, b]``.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence of exactly three tensors (condition first).
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    cond_t, x_t, y_t = inputs

    out_t = OutputShaper.selectOp(self.ser, self.rng, cond_t, x_t, y_t, error_name)

    param_count, const_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond_t.name, x_t.name, y_t.name], [out_t.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond_t,
        input2=x_t,
        input3=y_t,
        input_shape=x_t.shape,
        input_dtype=x_t.dtype,
        output_dtype=out_t.dtype,
        result_tensors=[out_t],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not passed:
        return None

    # SELECT carries no attribute.
    self.ser.addOperator(op["op"], in_names, out_names)

    # Compliance is keyed on the dtype of the second input (``a``).
    compliance_meta = self.tensorComplianceMetaData(
        op, x_t.dtype, args_dict, out_t, error_name
    )
    return TosaTestGen.BuildInfo(out_t, compliance_meta)
Eric Kunzee5e26762020-10-13 16:11:07 -0700730
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test from two inputs.

    The output tensor is produced by ``OutputShaper.binaryComparisonOp``.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence of exactly two tensors.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out = OutputShaper.binaryComparisonOp(self.ser, self.rng, lhs, rhs, error_name)

    param_count, const_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out.name]
    )

    error_if_args = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    # Comparison operators carry no attribute.
    self.ser.addOperator(op["op"], in_names, out_names)

    meta = self.tensorComplianceMetaData(op, lhs.dtype, args_dict, out, error_name)
    return TosaTestGen.BuildInfo(out, meta)
Eric Kunzee5e26762020-10-13 16:11:07 -0700778
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test along ``args_dict["axis"]``.

    ``validator_fcns`` and ``error_name`` now default to None, matching the
    other build_* functions; callers that pass them explicitly are unaffected.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments; must provide "axis".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700822
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Reads "kernel", "stride" and "pad" from ``args_dict``. "acc_type" is
    optional; when it is absent the accumulator dtype is ``DType.UNKNOWN``
    (max_pool has no accumulator dtype).

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence holding exactly one input tensor.
        args_dict: pooling arguments (see above).
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points ``[input_zp, output_zp]``, or None to use
            ``[0, 0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]

    # max_pool has no accum_dtype
    if "acc_type" in args_dict:
        accum_dtype = args_dict["acc_type"]
    else:
        accum_dtype = DType.UNKNOWN
    kernel, stride, pad = args_dict["kernel"], args_dict["stride"], args_dict["pad"]

    out = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # A WrongInputType test may feed a non-8-bit input, so regenerate the zero
    # points to match the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        zp_in = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        zp_out = TosaQuantGen.getZeroPoint(self, out.dtype)
        qinfo = [zp_in, zp_out]

    param_count, const_count = op["operands"]
    # ERROR_IF tests may deliberately corrupt the tensor name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
James Ward8b390432022-08-12 20:48:56 +0100897
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" (must
            have 4 entries) and "dilation".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for ConvAttribute. If None (the default), they
            are treated as ``[0, 0]`` instead of raising ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo.
    # NOTE(review): the second zero point is derived from the output dtype but
    # is passed to ConvAttribute in the weight position - confirm intent.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700979
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" (must
            have 6 entries) and "dilation".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for ConvAttribute. If None (the default), they
            are treated as ``[0, 0]`` instead of raising ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo.
    # NOTE(review): the second zero point is derived from the output dtype but
    # is passed to ConvAttribute in the weight position - confirm intent.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001061
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" (must
            have 4 entries; used as the out_pad) and "out_shape".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for TransposeConvAttribute. If None (the
            default), they are treated as ``[0, 0]`` instead of raising
            ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo.
    # NOTE(review): the second zero point is derived from the output dtype but
    # is passed to TransposeConvAttribute in the weight position - confirm.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001136
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: test arguments; reads "acc_type", "stride", "pad" and
            "dilation".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for ConvAttribute. If None (the default), they
            are treated as ``[0, 0]`` instead of raising ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo.
    # NOTE(review): the second zero point is derived from the output dtype but
    # is passed to ConvAttribute in the weight position - confirm intent.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001217
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: ``[ifm, weight, bias]`` tensors.
        args_dict: test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for FullyConnectedAttribute. If None (the
            default), they are treated as ``[0, 0]`` instead of raising
            ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001273
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test from two inputs ``[a, b]``.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence of exactly two tensors.
        args_dict: test arguments; reads "acc_type".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: two zero points for MatMulAttribute (presumably for ``a`` and
            ``b``; confirm against TosaQuantGen). If None (the default), they
            are treated as ``[0, 0]`` instead of raising ``TypeError``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero zero-points
    # (as build_pool2d does) rather than failing on qinfo[0].
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001323
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test along ``args_dict["axis"]``.

    ``validator_fcns`` now defaults to None, matching the other build_*
    functions; callers that pass it explicitly are unaffected.

    Side effect: for a valid (non-ERROR_IF) REDUCE_PRODUCT test, this writes
    ``args_dict["n"]`` (the size of the reduced axis) into the caller's dict,
    because the compliance metadata needs the number of products.

    Args:
        op: operator definition; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` unpacks into two operand counts.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments; must provide "axis".
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001372
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` case the values are redrawn until they differ
    and are then assigned the other way round.

    The bounds are written into the ClampAttribute:
      * BF16/FP16/FP32: the float slots, with FP16 values converted to float32.
      * INT8/INT16: the integer slots.
      * Any other dtype: all slots zero.

    Args:
        inputs: Exactly one input tensor.
        args_dict: Forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def draw_bounds():
        # Two independent random values of the input dtype.
        return [
            self.getRandNumberDType(src.dtype),
            self.getRandNumberDType(src.dtype),
        ]

    bounds = draw_bounds()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(bounds), max(bounds)
    else:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_bounds()
        min_val, max_val = max(bounds), min(bounds)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # ClampAttribute positional slots: (builder, min_int, max_int, min_fp, max_fp)
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            lo, hi = min_val.astype(np.float32), max_val.astype(np.float32)
        else:
            lo, hi = min_val, max_val
        clamp_args = (0, 0, lo, hi)
    elif src.dtype in (DType.INT8, DType.INT16):
        clamp_args = (min_val, max_val, 0, 0)
    else:
        # to avoid internal error for incorrect input types
        clamp_args = (0, 0, 0, 0)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, *clamp_args)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001441
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator (legacy calling convention).

    The input tensor is passed directly rather than as an ``inputs`` list.
    No ERROR_IF validation is done and ``validator_fcns`` is unused. The
    attribute value is a random FP32 number.

    Returns:
        The result tensor itself, not a ``BuildInfo``.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    # Random FP32 value for the attribute (presumably the negative-slope alpha).
    slope = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(slope)

    self.ser.addOperator(op["op"], [a.name], [out.name], leaky_attr)
    return out
1450
# NOTE: Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator (legacy calling convention).

    Only the single input ``a`` is wired up; no extra operand is added and
    no ERROR_IF validation is done (``validator_fcns`` is unused).

    Returns:
        The result tensor itself, not a ``BuildInfo``.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1457
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input, attribute-less elementwise operator.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode that gets
            serialized; ``op["operands"]`` is unpacked into two counts that
            are summed into ``num_operands``.
        inputs: Exactly one input tensor.
        args_dict: Forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Won Jeon78155c62023-06-10 00:20:04 +00001498
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator.

    CONCAT_SHAPE always concatenates along axis 0 and is serialized without
    an attribute. CONCAT takes its axis from ``args_dict["axis"]`` and writes
    it as an AxisAttribute.

    Args:
        op: Operator description dict. ``op["op"]`` selects CONCAT vs
            CONCAT_SHAPE; ``op["operands"]`` is unpacked into two counts that
            are summed into ``num_operands``.
        inputs: Tensors to concatenate. The first one supplies the
            shape/dtype reported to the validators and to compliance.
        args_dict: Generated arguments (``"axis"`` is read for CONCAT).
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance instead of an exact ``type(...) ==`` comparison.
        assert isinstance(axis, int), f"concat axis must be an int, got {axis!r}"

    result_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001557
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator whose padding is carried by a second input tensor.

    The PadAttribute is written with an empty padding list so that
    ``inputs[1]`` is the operand actually used. The pad constants come from
    ``args_dict["pad_const_int"]`` and ``args_dict["pad_const_fp"]``.

    Args:
        inputs: ``[input, padding_tensor]``.
        args_dict: Must contain ``"pad"``, ``"pad_const_int"`` and
            ``"pad_const_fp"``; also forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Passed to the validators.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, pad_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # write empty padding into PadAttribute to ensure inputs[1] is used
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, pad_tensor.name], [out.name]
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001615
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator that queries one axis of the input tensor.

    Args:
        inputs: Exactly one input tensor.
        args_dict: Must contain ``"axis"``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when ``evValidateErrorIfs`` returns a falsy
        value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, dim_axis, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001661
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator whose target shape is a second input tensor.

    The output shape is computed from ``args_dict["new_shape"]``. The
    serialized operator takes ``inputs[1]`` as the shape operand and has no
    attribute.

    Args:
        inputs: ``[input, shape_tensor]``.
        args_dict: Must contain ``"new_shape"``; also forwarded to
            ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs[0], inputs[1]
    target_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(self.ser, self.rng, src, target_shape, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape_tensor.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001705
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along ``args_dict["axis"]``.

    Args:
        inputs: Exactly one input tensor; the output has the same shape.
        args_dict: Must contain ``"axis"``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when ``evValidateErrorIfs`` returns a falsy
        value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    checks = dict(
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001745
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator using the permutation in ``args_dict["perms"]``.

    Args:
        inputs: Exactly one input tensor.
        args_dict: Must contain ``"perms"``; also forwarded to
            ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    permutation = args_dict["perms"]

    out = OutputShaper.transposeOp(self.ser, self.rng, src, permutation, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(permutation)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001794
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator whose start/size are supplied as input tensors.

    The output shape and the ERROR_IF checks use the constant values from
    ``args_dict["start"]`` and ``args_dict["size"]``. The serialized operator
    takes the start/size tensors as operands and has no attribute.

    Args:
        inputs: ``[input, start_tensor, size_tensor]``.
        args_dict: Must contain ``"start"`` and ``"size"``; also forwarded to
            ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    begin = args_dict["start"]
    extent = args_dict["size"]

    out = OutputShaper.sliceOp(self.ser, self.rng, src, begin, extent, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [out.name],
    )

    checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        start=begin,
        size=extent,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001842
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator whose multiples are supplied as a second input tensor.

    The output shape is computed from ``args_dict["multiples"]``. The
    serialized operator takes ``inputs[1]`` as an operand and has no
    attribute.

    Args:
        inputs: ``[input, multiples_tensor]``.
        args_dict: Must contain ``"multiples"``; also forwarded to
            ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    repeats = args_dict["multiples"]
    out = OutputShaper.tileOp(self.ser, self.rng, src, repeats, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out, self.tensorComplianceMetaData(op, src.dtype, args_dict, out, error_name)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07001887
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Args:
        inputs: ``[values, indices]``. The values tensor supplies the
            shape/dtype reported to the validators and to compliance.
        args_dict: Forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out.shape,
        input_dtype=values.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001930
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator.

    Args:
        op: Operator description dict. ``op["op"]`` is the opcode that gets
            serialized; ``op["operands"]`` is unpacked into two counts that
            are summed into ``num_operands``.
        inputs: ``[values_in, indices, input]`` tensors, in that order.
            ``values_in`` supplies the shape/dtype reported to the validators
            and to compliance.
        args_dict: Forwarded to ``tensorComplianceMetaData``.
        validator_fcns: ERROR_IF validators passed to ``evValidateErrorIfs``.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or None when ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    # Named ``input_tensor`` (not ``input``) to avoid shadowing the builtin.
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input_tensor.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001972
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    Args:
        op: operator dictionary; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: exactly four tensors: the input tensor followed by the
            scale, offset and border tensors.
        args_dict: must provide ``mode``, ``scale``, ``offset``, ``border``
            and ``output_dtype``.
        validator_fcns: ERROR_IF validator functions to run (defaults to
            None, matching the other builders).
        error_name: ERROR_IF test being generated, or None for a valid test.
        qinfo: not used.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 4
    input_tens, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tens,
        mode,
        scale,
        offset,
        border,
        input_tens.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tens.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tens.dtype,
        output_dtype=output_dtype,
        input_shape=input_tens.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # scale/offset/border are passed as operator input tensors (see
    # input_list), so empty lists are written into the ResizeAttribute.
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tens.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002051
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output identity-style operator.

    Each output tensor is shaped from its matching input by
    ``OutputShaper.unaryOp``. Only the first output tensor is returned.
    """
    # NOTE(review): unlike the other builders this passes ``op`` itself (not
    # ``op["op"]``) to addOperator and runs no ERROR_IF validation. Confirm
    # what callers supply before relying on it.
    outs = [
        OutputShaper.unaryOp(self.ser, self.rng, tens, error_name)
        for tens in (val, val2)
    ]
    self.ser.addOperator(op, [val.name, val2.name], [t.name for t in outs])
    return outs[0]
2059
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST operator test.

    The single tensor in ``inputs`` is registered directly as a graph output.
    Returns a ``TosaTestGen.BuildInfo`` for that tensor together with its
    compliance metadata.
    """
    assert len(inputs) == 1
    (const_tens,) = inputs
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002072
2073 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST (type conversion) operator test.

    Converts the single tensor in ``inputs`` to ``args_dict["out_type"]``.
    ``qinfo`` is not used.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata), or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    target_dtype = args_dict["out_type"]

    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    n_primary, n_const = op["operands"]
    # For ERROR_IF tests the operand/result name lists may be corrupted here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [converted.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=converted.shape,
        input_dtype=src.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_primary + n_const,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        converted,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, converted, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002117
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Picks random zero points suited to the input/output dtypes, or
    deliberately invalid ones for the zero-point ERROR_IF tests. Derives
    random per-tensor or per-channel multiplier/shift pairs. For valid
    ``scale32`` tests it also clamps the stored input data so that it meets
    the apply_scale_32 shift requirement.

    Args:
        op: operator dictionary; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: exactly one input tensor.
        args_dict: must provide ``output_dtype``, ``scale`` (the scale32
            flag), ``double_round`` and ``per_channel``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF test being generated, or None for a valid test.
        qinfo: ignored; the zero points are chosen here.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension), or a single pair.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Pick the input zero point. These branches also add one bit to the
    # input width used in the scale calculation below.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Deliberately invalid zero point, forced to be non-zero.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    # Same selection for the output zero point.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Deliberately invalid zero point, forced to be non-zero.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints. shift_arr stores int32 values,
        # and under NumPy 2 (NEP 50) promotion `-1 << np.int32(k)` stays
        # 32-bit, which is wrong once the shift exceeds 31 (the TOSA spec
        # allows shifts up to 62).
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRE(value >= (-1 << (shift - 1)) && value < (1 << (shift - 1)))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check every element can be safely converted back to the stored
        # dtype. (The previous check compared `val_adj.all()`, a bool, against
        # the limits, so it never checked anything.)
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting to the stored datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    zero_points = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=zero_points,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002300
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    The tensor holds ``[cond]``. Normally it is a rank-0 BOOL tensor. For the
    CondIfCondNotMatchingBool ERROR_IF test the dtype is replaced by a wrong
    type chosen by ``gtu.get_wrong_output_type``. For the
    CondIfCondShapeNotSizeOne test the shape is randomly [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2317
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output a constant.

    Of the two supplied tensors, only the first one's shape is used; their
    data is ignored. The condition value comes from
    ``args_dict["condition"]``. Each block emits a random INT32 constant of
    that shape. For the output-list mismatch ERROR_IF tests, the targeted
    block instead emits a constant with every dimension perturbed.

    Returns a ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or None
    when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    shape_src = inputs[0]

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = shape_src.shape
    dtype = DType.INT32

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_mismatch, else_mismatch):
        # Perturb every dimension so the block output cannot match the result.
        bad_shape = deepcopy(out_shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result takes the (correct) block output shape and dtype.
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Emit each block's constant body, swapping in the bad-shaped constant
    # for the block targeted by the ERROR_IF test.
    for block_name, block_arr, mismatch in (
        (then_block, then_arr, then_mismatch),
        (else_block, else_arr, else_mismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch:
            block_out = self.ser.addConst(bad_shape, dtype, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002402
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose blocks apply a binary op to ``a`` and ``b``.

    For float and INT32 dtypes the then-block adds and the else-block
    subtracts. For INT8/INT16 they use logical right and left shifts
    respectively. The condition value comes from ``args_dict["condition"]``.
    The compliance metadata is built for whichever op the condition selects.
    For the input/output-list mismatch ERROR_IF tests, the targeted block
    gets an input or output whose shape has every dimension perturbed.

    Returns a ``TosaTestGen.BuildInfo`` for the COND_IF result tensor, or None
    when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    then_in_err = ErrorIf.CondIfInputListThenGraphMismatch
    else_in_err = ErrorIf.CondIfInputListElseGraphMismatch
    then_out_err = ErrorIf.CondIfOutputListThenGraphMismatch
    else_out_err = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_in_err, else_in_err, else_out_err, then_out_err):
        # Copy of ``a`` (same name) with every dimension perturbed.
        bad_shape = a.shape.copy()
        for idx, dim in enumerate(bad_shape):
            bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
        bad_input = deepcopy(a)
        bad_input.shape = bad_shape

    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    # (then-block op, else-block op) keyed on the operand dtype
    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        op_names = ("add", "sub")
    elif a.dtype in (DType.INT8, DType.INT16):
        op_names = ("logical_right_shift", "logical_left_shift")
    else:
        assert False, f"No tests for DType: {a.dtype}"
    then_op, else_op = (self.TOSA_OP_LIST[name] for name in op_names)

    # Compliance checks the element-wise op the condition actually selects.
    compliance_op = then_op if cond else else_op

    for block_name, block_op, in_err, out_err in (
        (then_block, then_op, then_in_err, then_out_err),
        (else_block, else_op, else_in_err, else_out_err),
    ):
        self.ser.addBasicBlock(block_name)
        self.ser.addInputTensor(bad_input if error_name == in_err else a)
        self.ser.addInputTensor(b)
        block_shape = bad_input.shape if error_name == out_err else a.shape
        block_out = self.ser.addOutput(block_shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [block_out.name])

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            compliance_op, a.dtype, args_dict, result_tensor, error_name
        ),
    )
2506
2507 def build_while_loop(
2508 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2509 ):
2510 assert len(inputs) == 1
2511 a = inputs[0]
2512 iter_val = args_dict["iterations"]
2513
Kevin Cheng550ccc52021-03-03 11:21:43 -08002514 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002515
Kevin Cheng550ccc52021-03-03 11:21:43 -08002516 cond_block = "COND_BLOCK"
2517 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002518
2519 attr = ts.TosaSerializerAttribute()
2520 attr.WhileLoopAttribute(cond_block, body_block)
2521
2522 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002524 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002525 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002526
2527 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2529 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002530 if error_name == ErrorIf.InputListOutputListMismatch:
2531 incorrect_acc = deepcopy(acc)
2532 for i in range(len(incorrect_acc.shape)):
2533 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2534 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2535 else:
2536 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002537
2538 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002539 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002540 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002541 [iter.name, a.name, acc.name],
2542 [iter_out.name, a_out.name, acc_out.name],
2543 attr,
2544 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002545 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002546
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002547 if error_name in [
2548 ErrorIf.InputListCondGraphMismatch,
2549 ErrorIf.InputListBodyGraphInputMismatch,
2550 ErrorIf.InputListBodyGraphOutputMismatch,
2551 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002552 incorrect_iter = deepcopy(iter)
2553 for i in range(len(incorrect_iter.shape)):
2554 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2555 if len(incorrect_iter.shape) == 0:
2556 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2557
2558 incorrect_acc = deepcopy(acc)
2559 for i in range(len(incorrect_acc.shape)):
2560 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2561
Eric Kunzee5e26762020-10-13 16:11:07 -07002562 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002563 self.ser.addBasicBlock(cond_block)
2564
Matthew Haddon630c17c2021-10-14 15:05:41 +01002565 if error_name == ErrorIf.InputListCondGraphMismatch:
2566 self.ser.addInputTensor(incorrect_iter)
2567 self.ser.addInputTensor(a)
2568 self.ser.addInputTensor(incorrect_acc)
2569 else:
2570 self.ser.addInputTensor(iter)
2571 self.ser.addInputTensor(a)
2572 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002573 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002574
2575 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002576 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002577 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002578 cond_type = DType.BOOL
2579 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2580 choice = self.rng.choice([1, 2])
2581 if choice == 1:
2582 cond_shape = [3]
2583 else:
2584 cond_shape = [1, 2]
2585 else:
2586 cond_shape = []
2587 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002588
Kevin Cheng550ccc52021-03-03 11:21:43 -08002589 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002590
2591 # BODY block (input: a, acc, iter, output: a, acc, iter)
2592 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002593 self.ser.addBasicBlock(body_block)
2594
Matthew Haddon630c17c2021-10-14 15:05:41 +01002595 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2596 self.ser.addInputTensor(incorrect_iter)
2597 self.ser.addInputTensor(a)
2598 self.ser.addInputTensor(incorrect_acc)
2599 else:
2600 self.ser.addInputTensor(iter)
2601 self.ser.addInputTensor(a)
2602 self.ser.addInputTensor(acc)
2603
Kevin Cheng550ccc52021-03-03 11:21:43 -08002604 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002605
2606 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002607 iter_body_out = self.ser.addIntermediate(
2608 incorrect_iter.shape, incorrect_iter.dtype
2609 )
2610 acc_body_out = self.ser.addIntermediate(
2611 incorrect_acc.shape, incorrect_acc.dtype
2612 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002613 else:
2614 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2615 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2616
Eric Kunzee5e26762020-10-13 16:11:07 -07002617 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2618 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2619 self.ser.addOutputTensor(iter_body_out)
2620 self.ser.addOutputTensor(a)
2621 self.ser.addOutputTensor(acc_body_out)
2622
Les Bell729b0352021-11-24 10:28:21 +00002623 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002624 self.ser,
2625 validator_fcns,
2626 error_name,
2627 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002628 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002629 ):
2630 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002631
Jeremy Johnson587cc842024-02-08 11:45:44 +00002632 compliance = self.tensorComplianceMetaData(
2633 op, a.dtype, args_dict, acc_out, error_name
2634 )
2635
2636 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002637
Luke Hutton57287132023-02-06 14:54:18 +00002638 def build_fft2d(
Tai Lyd3797f02023-11-15 23:06:19 +00002639 self,
2640 op,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002641 inputs,
2642 args_dict,
Tai Lyd3797f02023-11-15 23:06:19 +00002643 validator_fcns=None,
2644 error_name=None,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002645 qinfo=None,
Luke Hutton57287132023-02-06 14:54:18 +00002646 ):
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002647 assert len(inputs) == 2
2648 val1, val2 = inputs
2649 inverse = args_dict["inverse"]
2650
Luke Hutton57287132023-02-06 14:54:18 +00002651 results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
2652
2653 input_names = [val1.name, val2.name]
2654 pCount, cCount = op["operands"]
2655 num_operands = pCount + cCount
2656
2657 output_names = [res.name for res in results]
2658 output_shapes = [res.shape for res in results]
2659 output_dtypes = [res.dtype for res in results]
2660
2661 input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2662 self, error_name, input_names, output_names
2663 )
2664
2665 if not TosaErrorValidator.evValidateErrorIfs(
2666 self.ser,
2667 validator_fcns,
2668 error_name,
2669 op=op,
2670 inverse=inverse,
2671 input1=val1,
2672 input2=val2,
2673 input_shape=val1.shape,
2674 input_dtype=val1.dtype,
2675 output_shape=output_shapes,
2676 output_dtype=output_dtypes,
2677 result_tensors=results,
2678 input_list=input_names,
2679 output_list=output_names,
2680 num_operands=num_operands,
2681 ):
2682 return None
2683
Tai Lyd3797f02023-11-15 23:06:19 +00002684 # TODO - Test local_bound, for now set local bound attribute to False
2685 local_bound = False
2686
Luke Hutton57287132023-02-06 14:54:18 +00002687 attr = ts.TosaSerializerAttribute()
Tai Lyd3797f02023-11-15 23:06:19 +00002688 attr.FFTAttribute(inverse, local_bound)
Luke Hutton57287132023-02-06 14:54:18 +00002689
2690 self.ser.addOperator(op["op"], input_names, output_names, attr)
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002691
2692 compliance = []
2693 for res in results:
2694 compliance.append(
2695 self.tensorComplianceMetaData(
2696 op, val1.dtype, args_dict, res, error_name
2697 )
2698 )
2699
2700 return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002701
Tai Lyd3797f02023-11-15 23:06:19 +00002702 def build_rfft2d(
2703 self,
2704 op,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002705 inputs,
2706 args_dict,
Tai Lyd3797f02023-11-15 23:06:19 +00002707 validator_fcns=None,
2708 error_name=None,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002709 qinfo=None,
Tai Lyd3797f02023-11-15 23:06:19 +00002710 ):
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002711 assert len(inputs) == 1
2712 val = inputs[0]
Luke Hutton261b7b62023-01-10 14:50:31 +00002713 results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
2714
2715 input_names = [val.name]
2716 pCount, cCount = op["operands"]
2717 num_operands = pCount + cCount
2718
2719 output_names = [res.name for res in results]
Luke Hutton57287132023-02-06 14:54:18 +00002720 output_shapes = [res.shape for res in results]
Luke Hutton261b7b62023-01-10 14:50:31 +00002721 output_dtypes = [res.dtype for res in results]
2722
2723 input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2724 self, error_name, input_names, output_names
2725 )
2726
2727 if not TosaErrorValidator.evValidateErrorIfs(
2728 self.ser,
2729 validator_fcns,
2730 error_name,
2731 op=op,
2732 input_shape=val.shape,
2733 input_dtype=val.dtype,
Luke Hutton57287132023-02-06 14:54:18 +00002734 output_shape=output_shapes,
Luke Hutton261b7b62023-01-10 14:50:31 +00002735 output_dtype=output_dtypes,
2736 result_tensors=results,
2737 input_list=input_names,
2738 output_list=output_names,
2739 num_operands=num_operands,
2740 ):
2741 return None
2742
Tai Lyd3797f02023-11-15 23:06:19 +00002743 # TODO - Test local_bound, for now set local bound attribute to False
2744 local_bound = False
2745
2746 attr = ts.TosaSerializerAttribute()
2747 attr.RFFTAttribute(local_bound)
2748
2749 self.ser.addOperator(op["op"], input_names, output_names, attr)
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002750
2751 compliance = []
2752 for res in results:
2753 compliance.append(
2754 self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
2755 )
2756
2757 return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002758
Won Jeon74342e52024-01-09 00:34:40 +00002759 def build_shape_op(
2760 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2761 ):
2762 assert len(inputs) == 2
2763 a, b = inputs
2764
2765 result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)
2766
2767 # Invalidate Input/Output list for error if checks.
2768 input_list = [a.name, b.name]
2769 output_list = [result_tensor.name]
2770 pCount, cCount = op["operands"]
2771 num_operands = pCount + cCount
2772 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
2773 self, error_name, input_list, output_list
2774 )
2775
2776 if not TosaErrorValidator.evValidateErrorIfs(
2777 self.ser,
2778 validator_fcns,
2779 error_name,
2780 op=op,
2781 input1=a,
2782 input2=b,
2783 input_shape=a.shape,
2784 input_dtype=a.dtype,
2785 output_shape=result_tensor.shape,
2786 output_dtype=result_tensor.dtype,
2787 result_tensors=[result_tensor],
2788 input_list=input_list,
2789 output_list=output_list,
2790 num_operands=num_operands,
2791 ):
2792 return None
2793
2794 self.ser.addOperator(
2795 op["op"],
2796 input_list,
2797 output_list,
2798 )
2799 compliance = self.tensorComplianceMetaData(
2800 op, a.dtype, args_dict, result_tensor, error_name
2801 )
2802
2803 return TosaTestGen.BuildInfo(result_tensor, compliance)
2804
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002805 def create_filter_lists(
2806 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2807 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002808 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2809 default_test_rank_range = range(1, 5)
2810 if not shapeFilter:
2811 shapeFilter = [None]
2812
2813 # Calculate the filters based on what is requested and what the operator allows
2814 rmin, rmax = op["rank"]
2815 if rankFilter is not None:
2816 cleanRankFilter = []
2817 # Ensure rankFilter values are allowed by operator
2818 for rank in rankFilter:
2819 if rank >= rmin and rank <= rmax:
2820 cleanRankFilter.append(rank)
2821 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002822 # Ensure default behaviour is bounded by default range or by operator,
2823 # whichever is the smaller range of ranks.
2824 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002825 cleanRankFilter = (
2826 opRankRange
2827 if len(opRankRange) <= len(default_test_rank_range)
2828 else default_test_rank_range
2829 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002830 else:
2831 cleanRankFilter = range(rmin, rmax + 1)
2832
2833 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002834
Matthew Haddon1c00b712021-10-01 15:51:03 +01002835 if dtypeFilter is not None:
2836 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002837 # Create list of operator dtypes filtered by requested dtypes
2838 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002839 if dtype in dtypeFilter or (
2840 isinstance(dtype, list) and dtype[0] in dtypeFilter
2841 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002842 cleanDtypeFilter.append(dtype)
2843 else:
2844 cleanDtypeFilter = dtypes
2845
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002846 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002847 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002848 "shapeFilter": shapeFilter,
2849 "rankFilter": cleanRankFilter,
2850 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002851 }
2852 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002853 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002854 if validator is not None:
2855 validator_info = validator(check=False, op=op)
2856 else:
2857 return None
2858
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002860
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002861 # Set parameters as required
2862 if error_arguments["rank"] is not None:
2863 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002864 else:
2865 rankFilter = cleanRankFilter
2866
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002867 if error_arguments["dtype"] is not None:
2868 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002869 else:
2870 dtypeFilter = cleanDtypeFilter
2871
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002872 if error_arguments["shape"] is not None:
2873 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002874 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002875 shapeFilter = shapeFilter[
2876 :2
2877 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002878
2879 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002880 "shapeFilter": shapeFilter,
2881 "rankFilter": rankFilter,
2882 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002883 }
2884 return filterDict
2885
Kevin Cheng550ccc52021-03-03 11:21:43 -08002886 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002887 self,
2888 opName,
2889 shapeFilter=[None],
2890 rankFilter=None,
2891 dtypeFilter=None,
2892 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002893 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002894
2895 try:
2896 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002897 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002898 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002899
2900 # Initialize a new random number generator
2901 self.rng = np.random.default_rng(self.random_seed)
2902
Jeremy Johnson1271c442023-09-05 11:39:26 +01002903 _, tgen_fcn, _, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002904
Eric Kunzee5e26762020-10-13 16:11:07 -07002905 # Test list consists of a tuple of:
2906 # (opName, testNameStr, dtype, shapeList, argumentsList)
2907 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002908 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002909 error_if_validators = op["error_if_validators"]
2910 else:
2911 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002912
Matthew Haddon1c00b712021-10-01 15:51:03 +01002913 for validator in error_if_validators:
2914 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002915 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002916 else:
2917 error_name = None
2918
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002919 filterDict = self.create_filter_lists(
2920 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2921 )
2922 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002923 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002924 cleanRankFilter = filterDict["rankFilter"]
2925 cleanDtypeFilter = filterDict["dtypeFilter"]
2926 cleanShapeFilter = filterDict["shapeFilter"]
2927 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002928
2929 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002930 for t in cleanDtypeFilter:
2931 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002932 # Filter out by rank
2933 if shape is not None and len(shape) != r:
2934 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002935 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002936 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002937
Matthew Haddon74567092021-07-16 15:38:20 +01002938 shapeStr = self.shapeStr(shapeList[0])
2939 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002940
Matthew Haddon74567092021-07-16 15:38:20 +01002941 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2942 argList = []
2943 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002944 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002945 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002946 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002947
Matthew Haddon74567092021-07-16 15:38:20 +01002948 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002949 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002950 if argStr:
2951 testStr = "{}_{}_{}_{}".format(
2952 opName, shapeStr, typeStr, argStr
2953 )
2954 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002955 testStr = "{}_{}_{}".format(
2956 opName, shapeStr, typeStr
2957 )
2958 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002959 if argStr:
2960 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2961 opName, error_name, shapeStr, typeStr, argStr
2962 )
2963 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002964 testStr = "{}_ERRORIF_{}_{}_{}".format(
2965 opName, error_name, shapeStr, typeStr
2966 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002967
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002968 testList.append(
2969 (opName, testStr, t, error_name, shapeList, args)
2970 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002971
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002972 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002973 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2974 if "invalid_test_validators" in op:
2975 invalid_test_validators = op["invalid_test_validators"]
2976 clean_testList = []
2977 for test in testList:
Jeremy Johnson0c716862023-04-13 17:18:19 +01002978 remove_test = False
Matthew Haddon1c00b712021-10-01 15:51:03 +01002979 for validator_fcn in invalid_test_validators:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002980 if validator_fcn(
2981 opName=test[0],
2982 input_dtype=test[2],
2983 shapeList=test[4],
2984 args=test[5],
2985 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002986 remove_test = True
2987 if not remove_test:
2988 clean_testList.append(test)
2989 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002990
2991 return testList
2992
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002993 def serializeTest(
Jeremy Johnson587cc842024-02-08 11:45:44 +00002994 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002995 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002996 try:
2997 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002998 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002999 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07003000
Jeremy Johnson0c716862023-04-13 17:18:19 +01003001 if self.args.verbose:
3002 print(f"Creating {testStr}")
3003
Eric Kunzee5e26762020-10-13 16:11:07 -07003004 # Create a serializer
3005 self.createSerializer(opName, testStr)
3006
Jeremy Johnson1271c442023-09-05 11:39:26 +01003007 build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01003008 if "error_if_validators" in op:
3009 error_if_validators = op["error_if_validators"]
3010 else:
3011 error_if_validators = None
3012
Kevin Cheng550ccc52021-03-03 11:21:43 -08003013 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07003014 num_operands = pCount + cCount
3015
3016 if isinstance(dtype_or_dtypeList, list):
3017 dtypeList = dtype_or_dtypeList
Won Jeon74342e52024-01-09 00:34:40 +00003018 elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
Matthew Haddon818ab902021-07-27 09:12:49 +01003019 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07003020 else:
3021 dtypeList = [dtype_or_dtypeList] * (num_operands)
3022
Won Jeon74342e52024-01-09 00:34:40 +00003023 if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
Matthew Haddon818ab902021-07-27 09:12:49 +01003024 assert (
3025 len(shapeList) == num_operands
3026 ), "shapeList length {} must match number of operands {}".format(
3027 len(shapeList), num_operands
3028 )
3029 assert (
3030 len(dtypeList) == num_operands
3031 ), "dtypeList length {} must match number of operands {}".format(
3032 len(dtypeList), num_operands
3033 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003034
3035 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003036 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003037 except KeyError:
3038 qgen = None
3039
3040 # Build the random tensor operands and the test
Kevin Chengaee1fac2020-11-11 13:54:06 -08003041
Matthew Haddon1c00b712021-10-01 15:51:03 +01003042 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003043 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003044 else:
3045 qinfo = None
3046
Jeremy Johnson1271c442023-09-05 11:39:26 +01003047 # Extra meta data for the desc.json
3048 tensMeta = {}
3049
Jeremy Johnson587cc842024-02-08 11:45:44 +00003050 # Check we are using the new interface with an argsDict dictionary
3051 assert isinstance(
3052 argsDict, dict
3053 ), f"{opName} is not using new tvg/build_fcn interface"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003054
Jeremy Johnson587cc842024-02-08 11:45:44 +00003055 # New interface with args info in dictionary
3056 assert "dg_type" in argsDict
3057 tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
3058 if tvgInfo.dataGenDict:
3059 tensMeta["data_gen"] = tvgInfo.dataGenDict
3060 tens = tvgInfo.tensorList
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00003061
Jeremy Johnson587cc842024-02-08 11:45:44 +00003062 result = build_fcn(
3063 self,
3064 op,
3065 tens,
3066 argsDict,
3067 validator_fcns=error_if_validators,
3068 error_name=error_name,
3069 qinfo=qinfo,
3070 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01003071
Jeremy Johnson1271c442023-09-05 11:39:26 +01003072 if result:
Les Bell729b0352021-11-24 10:28:21 +00003073 # The test is valid, serialize it
Jeremy Johnsonc8330812024-01-18 16:57:28 +00003074 if isinstance(result, TosaTestGen.BuildInfo):
3075 # Add the compliance meta data (if any)
3076 compliance = result.getComplianceInfo()
3077 if compliance:
3078 tensMeta["compliance"] = compliance
Jeremy Johnson1271c442023-09-05 11:39:26 +01003079 self.serialize("test", tensMeta)
Les Bell729b0352021-11-24 10:28:21 +00003080 else:
3081 # The test is not valid
3082 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01003083
Eric Kunzee5e26762020-10-13 16:11:07 -07003084 def createDynamicOpLists(self):
3085
Jeremy Johnson00423432022-09-12 17:27:37 +01003086 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
3087 # Already created these lists (can occur when class is initialized more than once)
3088 return
3089
Eric Kunzee5e26762020-10-13 16:11:07 -07003090 # Dynamically create op lists for convolutions with a list of kernel sizes
Jeremy Johnson0c716862023-04-13 17:18:19 +01003091 if not self.args.level8k:
3092 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
3093 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3094 else:
3095 bigK = self.TOSA_8K_LEVEL_MAX_KERNEL
3096 KERNELS_2D = [[1, bigK], [bigK, 2]]
3097 KERNELS_3D = [[1, bigK, 1], [2, 2, bigK]]
Eric Kunzee5e26762020-10-13 16:11:07 -07003098
Kevin Cheng1533b852021-09-01 12:51:58 -07003099 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003100 testName = "conv2d_{}x{}".format(k[0], k[1])
3101 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
3102 self.TOSA_OP_LIST[testName]["filter"] = k
3103 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003104
Kevin Cheng550ccc52021-03-03 11:21:43 -08003105 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
3106 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3107 "depthwise_conv2d_TEMPLATE"
3108 ].copy()
3109 self.TOSA_OP_LIST[testName]["filter"] = k
3110 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003111
Kevin Cheng550ccc52021-03-03 11:21:43 -08003112 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
3113 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
3114 "transpose_conv2d_TEMPLATE"
3115 ].copy()
3116 self.TOSA_OP_LIST[testName]["filter"] = k
3117 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07003118
Kevin Cheng1533b852021-09-01 12:51:58 -07003119 for k in KERNELS_3D:
3120 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
3121 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
3122 self.TOSA_OP_LIST[testName]["filter"] = k
3123 self.TOSA_OP_LIST[testName]["template"] = False
3124
Eric Kunzee5e26762020-10-13 16:11:07 -07003125 # Delete any templates after having created any dynamic ops
3126 # This is a two-pass operation because it's bad practice to delete
3127 # keys from dictionaries while iterating
3128 keyList = []
3129 for k in self.TOSA_OP_LIST:
3130 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003131 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07003132 keyList.append(k)
3133 continue
3134 except KeyError:
3135 pass
3136
3137 for k in keyList:
3138 del self.TOSA_OP_LIST[k]
3139
3140 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003141 """Fill in default fields for ops if they aren't already specified.
3142 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07003143 for op in self.TOSA_OP_LIST:
3144
3145 # Required fields
3146 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003147 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003148 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003149 raise Exception(
3150 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
3151 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003152
3153 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003154 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003155 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003156 raise Exception(
3157 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
3158 op
3159 )
3160 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003161
3162 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003163 _ = self.TOSA_OP_LIST[op]["types"]
3164 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003165 raise Exception(
3166 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
3167 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003168
3169 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003170 _ = self.TOSA_OP_LIST[op]["op"]
3171 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003172 raise Exception(
3173 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
3174 )
Eric Kunzee5e26762020-10-13 16:11:07 -07003175
3176 # Put in default rank range, if missing
3177 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003178 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07003179 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08003180 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003181
    # Tensor operator list (TOSA_OP_LIST, below). Each entry is a dict with:
    #  'op': the TOSA Op enum value to emit (e.g. Op.ARGMAX)
    #  'operands': tuple of (placeholder, const) operand counts
    #  'rank': optional, restricts rank to tuple inclusive of (min, max);
    #    if not specified, initOpListDefaults() sets it to DEFAULT_RANK_RANGE
    #  'build_fcn': 4-tuple of (build function, tensor shape generator,
    #    tensor values generator, argument generator)
    #  'types': array of datatypes to be tested; an entry may itself be a
    #    list of dtypes (see TYPE_CONV)
    # Optional fields read by the generator include 'qgen' (quantization
    # info generator), 'error_if_validators' (drive ERROR_IF negative tests),
    # 'invalid_test_validators' (remove positive tests expected to fail
    # without a matching ERROR_IF), 'data_gen', and 'template' (template
    # entries are expanded and then removed by createDynamicOpLists()).

    # Floating-point types
    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.INT32,
    ]  # floating-types and INT32
    # Floating-point, integer (excluding INT4) and boolean types
    TYPE_FIB = [
        DType.FP16,
        DType.BF16,
        DType.FP32,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.BOOL,
    ]
    TYPE_FI16 = [DType.FP32, DType.INT16]

    # INT8/INT16 plus the floating-point types (no INT32 or INT4)
    TYPE_NARROW_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]

    # List of [Input Type 1, Input Type 2, Accumulator Type]
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        [DType.FP16, DType.FP16, DType.FP16],
        [DType.FP16, DType.FP16, DType.FP32],
        [DType.BF16, DType.BF16, DType.FP32],
        [DType.FP32, DType.FP32, DType.FP32],
        [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
        [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
    ]

    # (min, max) rank applied by initOpListDefaults() to ops without a 'rank'
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003241
3242 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003244 "argmax": {
3245 "op": Op.ARGMAX,
3246 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003247 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003248 "build_fcn": (
3249 build_argmax,
3250 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003251 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003252 TosaArgGen.agAxis,
3253 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003254 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003255 "error_if_validators": (
3256 TosaErrorValidator.evAxisSmallerZero,
3257 TosaErrorValidator.evAxisLargerRank,
3258 TosaErrorValidator.evArgmaxOutputRankMismatch,
3259 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3260 TosaErrorValidator.evWrongRank,
3261 TosaErrorValidator.evWrongInputType,
3262 TosaErrorValidator.evWrongOutputType,
3263 TosaErrorValidator.evWrongInputList,
3264 TosaErrorValidator.evWrongOutputList,
3265 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003266 "data_gen": {
3267 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3268 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "avg_pool2d": {
3271 "op": Op.AVG_POOL2D,
3272 "operands": (1, 0),
3273 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003274 "build_fcn": (
3275 build_pool2d,
3276 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003277 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003278 TosaArgGen.agPooling,
3279 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003281 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003282 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003283 "error_if_validators": (
3284 TosaErrorValidator.evKernelSmallerOne,
3285 TosaErrorValidator.evStrideSmallerOne,
3286 TosaErrorValidator.evPadSmallerZero,
3287 TosaErrorValidator.evWrongRank,
3288 TosaErrorValidator.evWrongInputType,
3289 TosaErrorValidator.evWrongOutputType,
3290 TosaErrorValidator.evWrongInputList,
3291 TosaErrorValidator.evWrongOutputList,
3292 TosaErrorValidator.evInputZeroPointNotZero,
3293 TosaErrorValidator.evOutputZeroPointNotZero,
3294 TosaErrorValidator.evPadLargerEqualKernel,
3295 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003296 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003297 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003298 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003299 "data_gen": {
3300 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003303 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003304 "conv2d_TEMPLATE": {
3305 "op": Op.CONV2D,
3306 "operands": (1, 2),
3307 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 "build_fcn": (
3309 build_conv2d,
3310 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003311 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 TosaArgGen.agConv,
3313 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003314 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003315 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003316 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3317 "error_if_validators": (
3318 TosaErrorValidator.evWrongInputType,
3319 TosaErrorValidator.evWrongOutputType,
3320 TosaErrorValidator.evWrongInputList,
3321 TosaErrorValidator.evWrongOutputList,
3322 TosaErrorValidator.evInputZeroPointNotZero,
3323 TosaErrorValidator.evWeightZeroPointNotZero,
3324 TosaErrorValidator.evPadSmallerZero,
3325 TosaErrorValidator.evStrideSmallerOne,
3326 TosaErrorValidator.evDilationSmallerOne,
3327 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003328 TosaErrorValidator.evConvOutputShapeMismatch,
3329 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003330 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003331 "data_gen": {
3332 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3333 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003334 "template": True,
3335 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003336 # Templated operator. Filled in by createDynamicOpLists
3337 "conv3d_TEMPLATE": {
3338 "op": Op.CONV3D,
3339 "operands": (1, 2),
3340 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 "build_fcn": (
3342 build_conv3d,
3343 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003344 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003345 TosaArgGen.agConv,
3346 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003347 "qgen": TosaQuantGen.qgConv,
3348 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003349 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3350 "error_if_validators": (
3351 TosaErrorValidator.evWrongInputType,
3352 TosaErrorValidator.evWrongOutputType,
3353 TosaErrorValidator.evWrongInputList,
3354 TosaErrorValidator.evWrongOutputList,
3355 TosaErrorValidator.evInputZeroPointNotZero,
3356 TosaErrorValidator.evWeightZeroPointNotZero,
3357 TosaErrorValidator.evPadSmallerZero,
3358 TosaErrorValidator.evStrideSmallerOne,
3359 TosaErrorValidator.evDilationSmallerOne,
3360 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003361 TosaErrorValidator.evConvOutputShapeMismatch,
3362 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003363 ),
evacha0147ab1762024-01-29 13:23:23 +00003364 "data_gen": {
3365 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3366 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003367 "template": True,
3368 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003369 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003370 "depthwise_conv2d_TEMPLATE": {
3371 "op": Op.DEPTHWISE_CONV2D,
3372 "operands": (1, 2),
3373 "filter": [1, 1],
3374 "rank": (4, 4),
3375 "build_fcn": (
3376 build_depthwise_conv2d,
3377 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003378 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003379 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003380 ),
3381 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003382 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003383 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3384 "error_if_validators": (
3385 TosaErrorValidator.evWrongInputType,
3386 TosaErrorValidator.evWrongOutputType,
3387 TosaErrorValidator.evWrongInputList,
3388 TosaErrorValidator.evWrongOutputList,
3389 TosaErrorValidator.evInputZeroPointNotZero,
3390 TosaErrorValidator.evWeightZeroPointNotZero,
3391 TosaErrorValidator.evPadSmallerZero,
3392 TosaErrorValidator.evStrideSmallerOne,
3393 TosaErrorValidator.evDilationSmallerOne,
3394 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003395 TosaErrorValidator.evConvOutputShapeMismatch,
3396 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003397 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003398 "data_gen": {
3399 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3400 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003401 "template": True,
3402 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 "fully_connected": {
3404 "op": Op.FULLY_CONNECTED,
3405 "operands": (1, 2),
3406 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 "build_fcn": (
3408 build_fully_connected,
3409 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003410 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003411 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003414 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003415 "error_if_validators": (
3416 TosaErrorValidator.evInputZeroPointNotZero,
3417 TosaErrorValidator.evWeightZeroPointNotZero,
3418 TosaErrorValidator.evWrongRank,
3419 TosaErrorValidator.evWrongInputType,
3420 TosaErrorValidator.evWrongOutputType,
3421 TosaErrorValidator.evWrongInputList,
3422 TosaErrorValidator.evWrongOutputList,
3423 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003424 "data_gen": {
3425 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3426 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003427 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 "matmul": {
3429 "op": Op.MATMUL,
3430 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003431 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003432 "build_fcn": (
3433 build_matmul,
3434 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003435 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003436 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003437 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003439 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003440 "error_if_validators": (
3441 TosaErrorValidator.evInputZeroPointNotZero,
3442 TosaErrorValidator.evWrongRank,
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003448 "data_gen": {
3449 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003450 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003451 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003452 "max_pool2d": {
3453 "op": Op.MAX_POOL2D,
3454 "operands": (1, 0),
3455 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003456 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003457 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003458 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003459 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003460 TosaArgGen.agPooling,
3461 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003462 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003463 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003464 "error_if_validators": (
3465 TosaErrorValidator.evKernelSmallerOne,
3466 TosaErrorValidator.evStrideSmallerOne,
3467 TosaErrorValidator.evPadSmallerZero,
3468 TosaErrorValidator.evWrongRank,
3469 TosaErrorValidator.evWrongInputType,
3470 TosaErrorValidator.evWrongOutputType,
3471 TosaErrorValidator.evWrongInputList,
3472 TosaErrorValidator.evWrongOutputList,
3473 TosaErrorValidator.evPadLargerEqualKernel,
3474 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003475 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003476 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003477 "data_gen": {
3478 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3479 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003480 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003481 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003482 "transpose_conv2d_TEMPLATE": {
3483 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003484 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 "rank": (4, 4),
3486 "build_fcn": (
3487 build_transpose_conv2d,
3488 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003489 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003490 TosaArgGen.agTransposeConv2D,
3491 ),
3492 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003493 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003494 "invalid_test_validators": (
3495 TosaInvalidValidator.ivHeightWidthInvalid,
3496 TosaInvalidValidator.ivNonPositiveOutputShape,
3497 ),
3498 "error_if_validators": (
3499 TosaErrorValidator.evWrongInputType,
3500 TosaErrorValidator.evWrongOutputType,
3501 TosaErrorValidator.evWrongInputList,
3502 TosaErrorValidator.evWrongOutputList,
3503 TosaErrorValidator.evInputZeroPointNotZero,
3504 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003505 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003506 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003507 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003508 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003509 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003510 "data_gen": {
3511 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3512 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003513 "template": True,
3514 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003515 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003516 "clamp": {
3517 "op": Op.CLAMP,
3518 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 "build_fcn": (
3520 build_clamp,
3521 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003522 TosaTensorValuesGen.tvgLazyGenDefault,
3523 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003524 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003525 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evMaxSmallerMin,
3528 TosaErrorValidator.evWrongInputType,
3529 TosaErrorValidator.evWrongOutputType,
3530 TosaErrorValidator.evWrongInputList,
3531 TosaErrorValidator.evWrongOutputList,
3532 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003533 "data_gen": {
3534 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3535 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003536 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003537 "sigmoid": {
3538 "op": Op.SIGMOID,
3539 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003540 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003541 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003543 TosaTensorValuesGen.tvgLazyGenDefault,
3544 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003546 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003547 "error_if_validators": (
3548 TosaErrorValidator.evWrongInputType,
3549 TosaErrorValidator.evWrongOutputType,
3550 TosaErrorValidator.evWrongInputList,
3551 TosaErrorValidator.evWrongOutputList,
3552 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003553 "data_gen": {
3554 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3555 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003556 },
3557 "tanh": {
3558 "op": Op.TANH,
3559 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003560 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003561 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003562 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003563 TosaTensorValuesGen.tvgLazyGenDefault,
3564 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003565 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003566 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003567 "error_if_validators": (
3568 TosaErrorValidator.evWrongInputType,
3569 TosaErrorValidator.evWrongOutputType,
3570 TosaErrorValidator.evWrongInputList,
3571 TosaErrorValidator.evWrongOutputList,
3572 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003573 "data_gen": {
3574 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3575 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003576 "compliance": {
3577 "abs_error_lower_bound": 0.5,
3578 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003579 },
Won Jeon78155c62023-06-10 00:20:04 +00003580 "erf": {
3581 "op": Op.ERF,
3582 "operands": (1, 0),
3583 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003584 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003585 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003586 TosaTensorValuesGen.tvgLazyGenDefault,
3587 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003588 ),
3589 "types": TYPE_FP,
3590 "error_if_validators": (
3591 TosaErrorValidator.evWrongInputType,
3592 TosaErrorValidator.evWrongOutputType,
3593 TosaErrorValidator.evWrongInputList,
3594 TosaErrorValidator.evWrongOutputList,
3595 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003596 "data_gen": {
3597 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3598 },
3599 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003600 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003601 # Elementwise Binary Operators
3602 "add": {
3603 "op": Op.ADD,
3604 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003605 "build_fcn": (
3606 build_binary_broadcast,
3607 TosaTensorGen.tgBroadcastFuzz,
3608 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003609 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003610 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003611 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003612 "error_if_validators": (
3613 TosaErrorValidator.evRankMismatch,
3614 TosaErrorValidator.evWrongInputType,
3615 TosaErrorValidator.evWrongOutputType,
3616 TosaErrorValidator.evWrongInputList,
3617 TosaErrorValidator.evWrongOutputList,
3618 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003619 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003620 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003621 "data_gen": {
3622 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3623 },
3624 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003626 "arithmetic_right_shift": {
3627 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3628 "operands": (2, 0),
3629 "build_fcn": (
3630 build_arithmetic_right_shift,
3631 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003633 TosaArgGen.agArithmeticRightShift,
3634 ),
3635 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003636 "error_if_validators": (
3637 TosaErrorValidator.evRankMismatch,
3638 TosaErrorValidator.evWrongInputType,
3639 TosaErrorValidator.evWrongOutputType,
3640 TosaErrorValidator.evWrongInputList,
3641 TosaErrorValidator.evWrongOutputList,
3642 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003643 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003646 "bitwise_and": {
3647 "op": Op.BITWISE_AND,
3648 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003649 "build_fcn": (
3650 build_binary_broadcast,
3651 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003652 TosaTensorValuesGen.tvgLazyGenDefault,
3653 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003655 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003656 "error_if_validators": (
3657 TosaErrorValidator.evRankMismatch,
3658 TosaErrorValidator.evWrongInputType,
3659 TosaErrorValidator.evWrongOutputType,
3660 TosaErrorValidator.evWrongInputList,
3661 TosaErrorValidator.evWrongOutputList,
3662 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003663 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003665 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 "bitwise_or": {
3667 "op": Op.BITWISE_OR,
3668 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003669 "build_fcn": (
3670 build_binary_broadcast,
3671 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003672 TosaTensorValuesGen.tvgLazyGenDefault,
3673 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003674 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003676 "error_if_validators": (
3677 TosaErrorValidator.evRankMismatch,
3678 TosaErrorValidator.evWrongInputType,
3679 TosaErrorValidator.evWrongOutputType,
3680 TosaErrorValidator.evWrongInputList,
3681 TosaErrorValidator.evWrongOutputList,
3682 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003683 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003686 "bitwise_xor": {
3687 "op": Op.BITWISE_XOR,
3688 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003689 "build_fcn": (
3690 build_binary_broadcast,
3691 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003692 TosaTensorValuesGen.tvgLazyGenDefault,
3693 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003694 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003696 "error_if_validators": (
3697 TosaErrorValidator.evRankMismatch,
3698 TosaErrorValidator.evWrongInputType,
3699 TosaErrorValidator.evWrongOutputType,
3700 TosaErrorValidator.evWrongInputList,
3701 TosaErrorValidator.evWrongOutputList,
3702 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003703 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003704 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003706 "intdiv": {
3707 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003708 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 "build_fcn": (
3710 build_binary_broadcast,
3711 TosaTensorGen.tgBroadcastFuzz,
3712 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003713 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003715 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003716 "error_if_validators": (
3717 TosaErrorValidator.evRankMismatch,
3718 TosaErrorValidator.evWrongInputType,
3719 TosaErrorValidator.evWrongOutputType,
3720 TosaErrorValidator.evWrongInputList,
3721 TosaErrorValidator.evWrongOutputList,
3722 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003723 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003724 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003726 "logical_and": {
3727 "op": Op.LOGICAL_AND,
3728 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003729 "build_fcn": (
3730 build_binary_broadcast,
3731 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003732 TosaTensorValuesGen.tvgLazyGenDefault,
3733 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003734 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003735 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evRankMismatch,
3738 TosaErrorValidator.evWrongInputType,
3739 TosaErrorValidator.evWrongOutputType,
3740 TosaErrorValidator.evWrongInputList,
3741 TosaErrorValidator.evWrongOutputList,
3742 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003743 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 "logical_left_shift": {
3747 "op": Op.LOGICAL_LEFT_SHIFT,
3748 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 "build_fcn": (
3750 build_binary_broadcast,
3751 TosaTensorGen.tgBroadcastFuzz,
3752 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003753 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 "error_if_validators": (
3757 TosaErrorValidator.evRankMismatch,
3758 TosaErrorValidator.evWrongInputType,
3759 TosaErrorValidator.evWrongOutputType,
3760 TosaErrorValidator.evWrongInputList,
3761 TosaErrorValidator.evWrongOutputList,
3762 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003763 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 "logical_right_shift": {
3767 "op": Op.LOGICAL_RIGHT_SHIFT,
3768 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 "build_fcn": (
3770 build_binary_broadcast,
3771 TosaTensorGen.tgBroadcastFuzz,
3772 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003773 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 "error_if_validators": (
3777 TosaErrorValidator.evRankMismatch,
3778 TosaErrorValidator.evWrongInputType,
3779 TosaErrorValidator.evWrongOutputType,
3780 TosaErrorValidator.evWrongInputList,
3781 TosaErrorValidator.evWrongOutputList,
3782 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003783 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003784 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "logical_or": {
3787 "op": Op.LOGICAL_OR,
3788 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_binary_broadcast,
3791 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003792 TosaTensorValuesGen.tvgLazyGenDefault,
3793 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003796 "error_if_validators": (
3797 TosaErrorValidator.evRankMismatch,
3798 TosaErrorValidator.evWrongInputType,
3799 TosaErrorValidator.evWrongOutputType,
3800 TosaErrorValidator.evWrongInputList,
3801 TosaErrorValidator.evWrongOutputList,
3802 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003803 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "logical_xor": {
3807 "op": Op.LOGICAL_XOR,
3808 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 "build_fcn": (
3810 build_binary_broadcast,
3811 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003812 TosaTensorValuesGen.tvgLazyGenDefault,
3813 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003815 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evRankMismatch,
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003823 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003824 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003825 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003826 "maximum": {
3827 "op": Op.MAXIMUM,
3828 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 "build_fcn": (
3830 build_binary_broadcast,
3831 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003832 TosaTensorValuesGen.tvgLazyGenDefault,
3833 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003834 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003835 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003836 "error_if_validators": (
3837 TosaErrorValidator.evRankMismatch,
3838 TosaErrorValidator.evWrongInputType,
3839 TosaErrorValidator.evWrongOutputType,
3840 TosaErrorValidator.evWrongInputList,
3841 TosaErrorValidator.evWrongOutputList,
3842 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003843 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003844 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003845 "data_gen": {
3846 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3847 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003848 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003849 "minimum": {
3850 "op": Op.MINIMUM,
3851 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 "build_fcn": (
3853 build_binary_broadcast,
3854 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003855 TosaTensorValuesGen.tvgLazyGenDefault,
3856 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 "error_if_validators": (
3860 TosaErrorValidator.evRankMismatch,
3861 TosaErrorValidator.evWrongInputType,
3862 TosaErrorValidator.evWrongOutputType,
3863 TosaErrorValidator.evWrongInputList,
3864 TosaErrorValidator.evWrongOutputList,
3865 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003866 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003867 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003868 "data_gen": {
3869 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3870 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003871 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 "mul": {
3873 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003874 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003875 "build_fcn": (
3876 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003877 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003878 TosaTensorValuesGen.tvgMul,
3879 TosaArgGen.agMul,
3880 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 "error_if_validators": (
3883 TosaErrorValidator.evWrongInputType,
3884 TosaErrorValidator.evWrongOutputType,
3885 TosaErrorValidator.evWrongInputList,
3886 TosaErrorValidator.evWrongOutputList,
3887 TosaErrorValidator.evRankMismatch,
3888 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003889 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003890 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003891 "data_gen": {
3892 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3893 },
3894 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003895 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003896 "pow": {
3897 "op": Op.POW,
3898 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003899 "build_fcn": (
3900 build_binary_broadcast,
3901 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003902 TosaTensorValuesGen.tvgPow,
3903 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003905 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003906 "error_if_validators": (
3907 TosaErrorValidator.evRankMismatch,
3908 TosaErrorValidator.evWrongInputType,
3909 TosaErrorValidator.evWrongOutputType,
3910 TosaErrorValidator.evWrongInputList,
3911 TosaErrorValidator.evWrongOutputList,
3912 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003913 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003914 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003915 "data_gen": {
3916 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3917 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003918 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003919 "sub": {
3920 "op": Op.SUB,
3921 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003922 "build_fcn": (
3923 build_binary_broadcast,
3924 TosaTensorGen.tgBroadcastFuzz,
3925 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003926 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003927 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003928 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003929 "error_if_validators": (
3930 TosaErrorValidator.evRankMismatch,
3931 TosaErrorValidator.evWrongInputType,
3932 TosaErrorValidator.evWrongOutputType,
3933 TosaErrorValidator.evWrongInputList,
3934 TosaErrorValidator.evWrongOutputList,
3935 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003936 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003937 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003938 "data_gen": {
3939 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3940 },
3941 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003942 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 "table": {
3944 "op": Op.TABLE,
3945 # Use the automatic generation functions to create the input array
3946 # but create the table tensor in the build function, as it may be
3947 # a different type from the input
3948 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003949 "build_fcn": (
3950 build_table,
3951 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003952 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003953 TosaArgGen.agTable,
3954 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003955 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003956 "error_if_validators": (
3957 TosaErrorValidator.evWrongInputType,
3958 TosaErrorValidator.evWrongOutputType,
3959 TosaErrorValidator.evWrongInputList,
3960 TosaErrorValidator.evWrongOutputList,
3961 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 # Elementwise Unary operators
3964 "abs": {
3965 "op": Op.ABS,
3966 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 "build_fcn": (
3968 build_unary,
3969 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003970 TosaTensorValuesGen.tvgLazyGenDefault,
3971 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003973 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003974 "error_if_validators": (
3975 TosaErrorValidator.evWrongInputType,
3976 TosaErrorValidator.evWrongOutputType,
3977 TosaErrorValidator.evWrongInputList,
3978 TosaErrorValidator.evWrongOutputList,
3979 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003980 "data_gen": {
3981 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3982 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "bitwise_not": {
3985 "op": Op.BITWISE_NOT,
3986 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003987 "build_fcn": (
3988 build_unary,
3989 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003990 TosaTensorValuesGen.tvgLazyGenDefault,
3991 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003994 "error_if_validators": (
3995 TosaErrorValidator.evWrongInputType,
3996 TosaErrorValidator.evWrongOutputType,
3997 TosaErrorValidator.evWrongInputList,
3998 TosaErrorValidator.evWrongOutputList,
3999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004001 "ceil": {
4002 "op": Op.CEIL,
4003 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004004 "build_fcn": (
4005 build_unary,
4006 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004007 TosaTensorValuesGen.tvgLazyGenDefault,
4008 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004011 "error_if_validators": (
4012 TosaErrorValidator.evWrongInputType,
4013 TosaErrorValidator.evWrongOutputType,
4014 TosaErrorValidator.evWrongInputList,
4015 TosaErrorValidator.evWrongOutputList,
4016 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004017 "data_gen": {
4018 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4019 },
4020 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004022 "clz": {
4023 "op": Op.CLZ,
4024 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004025 "build_fcn": (
4026 build_unary,
4027 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004028 TosaTensorValuesGen.tvgLazyGenDefault,
4029 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004032 "error_if_validators": (
4033 TosaErrorValidator.evWrongInputType,
4034 TosaErrorValidator.evWrongOutputType,
4035 TosaErrorValidator.evWrongInputList,
4036 TosaErrorValidator.evWrongOutputList,
4037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004038 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004039 "exp": {
4040 "op": Op.EXP,
4041 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004042 "build_fcn": (
4043 build_unary,
4044 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004045 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004046 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004048 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004049 "error_if_validators": (
4050 TosaErrorValidator.evWrongInputType,
4051 TosaErrorValidator.evWrongOutputType,
4052 TosaErrorValidator.evWrongInputList,
4053 TosaErrorValidator.evWrongOutputList,
4054 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004055 "data_gen": {
4056 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4057 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "floor": {
4060 "op": Op.FLOOR,
4061 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_unary,
4064 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004065 TosaTensorValuesGen.tvgLazyGenDefault,
4066 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evWrongInputType,
4071 TosaErrorValidator.evWrongOutputType,
4072 TosaErrorValidator.evWrongInputList,
4073 TosaErrorValidator.evWrongOutputList,
4074 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004075 "data_gen": {
4076 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4077 },
4078 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004080 "log": {
4081 "op": Op.LOG,
4082 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004083 "build_fcn": (
4084 build_unary,
4085 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004086 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004087 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004088 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004089 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004090 "error_if_validators": (
4091 TosaErrorValidator.evWrongInputType,
4092 TosaErrorValidator.evWrongOutputType,
4093 TosaErrorValidator.evWrongInputList,
4094 TosaErrorValidator.evWrongOutputList,
4095 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004096 "data_gen": {
4097 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4098 },
4099 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004101 "logical_not": {
4102 "op": Op.LOGICAL_NOT,
4103 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 "build_fcn": (
4105 build_unary,
4106 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004107 TosaTensorValuesGen.tvgLazyGenDefault,
4108 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004109 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004110 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004111 "error_if_validators": (
4112 TosaErrorValidator.evWrongInputType,
4113 TosaErrorValidator.evWrongOutputType,
4114 TosaErrorValidator.evWrongInputList,
4115 TosaErrorValidator.evWrongOutputList,
4116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004118 "negate": {
4119 "op": Op.NEGATE,
4120 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004121 "build_fcn": (
4122 build_unary,
4123 TosaTensorGen.tgBasic,
4124 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004125 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004126 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "qgen": TosaQuantGen.qgUnary,
4128 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004129 "error_if_validators": (
4130 TosaErrorValidator.evInputZeroPointNotZero,
4131 TosaErrorValidator.evOutputZeroPointNotZero,
4132 TosaErrorValidator.evWrongInputType,
4133 TosaErrorValidator.evWrongOutputType,
4134 TosaErrorValidator.evWrongInputList,
4135 TosaErrorValidator.evWrongOutputList,
4136 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004137 "data_gen": {
4138 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4139 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004141 "reciprocal": {
4142 "op": Op.RECIPROCAL,
4143 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004144 "build_fcn": (
4145 build_unary,
4146 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004147 TosaTensorValuesGen.tvgLazyGenDefault,
4148 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004150 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004151 "error_if_validators": (
4152 TosaErrorValidator.evWrongInputType,
4153 TosaErrorValidator.evWrongOutputType,
4154 TosaErrorValidator.evWrongInputList,
4155 TosaErrorValidator.evWrongOutputList,
4156 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004157 "data_gen": {
4158 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4159 },
4160 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004161 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004162 "rsqrt": {
4163 "op": Op.RSQRT,
4164 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004165 "build_fcn": (
4166 build_unary,
4167 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004168 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004169 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004170 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004171 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004172 "error_if_validators": (
4173 TosaErrorValidator.evWrongInputType,
4174 TosaErrorValidator.evWrongOutputType,
4175 TosaErrorValidator.evWrongInputList,
4176 TosaErrorValidator.evWrongOutputList,
4177 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004178 "data_gen": {
4179 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4180 },
4181 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004183 # Elementwise Ternary operators
4184 "select": {
4185 "op": Op.SELECT,
4186 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004187 "build_fcn": (
4188 build_select,
4189 TosaTensorGen.tgBroadcastFuzz,
4190 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004191 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004192 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004193 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004194 "error_if_validators": (
4195 TosaErrorValidator.evRankMismatch,
4196 TosaErrorValidator.evWrongInputType,
4197 TosaErrorValidator.evWrongOutputType,
4198 TosaErrorValidator.evWrongInputList,
4199 TosaErrorValidator.evWrongOutputList,
4200 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004201 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004202 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004203 "data_gen": {
4204 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004207 # Comparison operators
4208 "equal": {
4209 "op": Op.EQUAL,
4210 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004211 "build_fcn": (
4212 build_comparison,
4213 TosaTensorGen.tgBroadcastFuzz,
4214 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004215 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004216 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004217 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004218 "error_if_validators": (
4219 TosaErrorValidator.evRankMismatch,
4220 TosaErrorValidator.evWrongInputType,
4221 TosaErrorValidator.evWrongOutputType,
4222 TosaErrorValidator.evWrongInputList,
4223 TosaErrorValidator.evWrongOutputList,
4224 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004225 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004226 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004227 "data_gen": {
4228 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 "greater_equal": {
4232 "op": Op.GREATER_EQUAL,
4233 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004234 "build_fcn": (
4235 build_comparison,
4236 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004237 TosaTensorValuesGen.tvgLazyGenDefault,
4238 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004239 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004240 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004241 "error_if_validators": (
4242 TosaErrorValidator.evRankMismatch,
4243 TosaErrorValidator.evWrongInputType,
4244 TosaErrorValidator.evWrongOutputType,
4245 TosaErrorValidator.evWrongInputList,
4246 TosaErrorValidator.evWrongOutputList,
4247 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004248 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004249 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004250 "data_gen": {
4251 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 "greater": {
4255 "op": Op.GREATER,
4256 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004257 "build_fcn": (
4258 build_comparison,
4259 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004260 TosaTensorValuesGen.tvgLazyGenDefault,
4261 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004263 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004264 "error_if_validators": (
4265 TosaErrorValidator.evRankMismatch,
4266 TosaErrorValidator.evWrongInputType,
4267 TosaErrorValidator.evWrongOutputType,
4268 TosaErrorValidator.evWrongInputList,
4269 TosaErrorValidator.evWrongOutputList,
4270 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004271 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004272 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004273 "data_gen": {
4274 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004277 # Reduction operators
4278 "reduce_all": {
4279 "op": Op.REDUCE_ALL,
4280 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004281 "build_fcn": (
4282 build_reduce,
4283 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004284 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004285 TosaArgGen.agAxis,
4286 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004287 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004288 "error_if_validators": (
4289 TosaErrorValidator.evAxisLargerRank,
4290 TosaErrorValidator.evAxisSmallerZero,
4291 TosaErrorValidator.evShapeOfAxisNotOne,
4292 TosaErrorValidator.evWrongInputType,
4293 TosaErrorValidator.evWrongOutputType,
4294 TosaErrorValidator.evWrongRank,
4295 TosaErrorValidator.evWrongInputList,
4296 TosaErrorValidator.evWrongOutputList,
4297 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004298 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004299 "reduce_any": {
4300 "op": Op.REDUCE_ANY,
4301 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004302 "build_fcn": (
4303 build_reduce,
4304 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004305 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 TosaArgGen.agAxis,
4307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004308 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004309 "error_if_validators": (
4310 TosaErrorValidator.evAxisLargerRank,
4311 TosaErrorValidator.evAxisSmallerZero,
4312 TosaErrorValidator.evShapeOfAxisNotOne,
4313 TosaErrorValidator.evWrongInputType,
4314 TosaErrorValidator.evWrongOutputType,
4315 TosaErrorValidator.evWrongRank,
4316 TosaErrorValidator.evWrongInputList,
4317 TosaErrorValidator.evWrongOutputList,
4318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004320 "reduce_max": {
4321 "op": Op.REDUCE_MAX,
4322 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004323 "build_fcn": (
4324 build_reduce,
4325 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004326 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004327 TosaArgGen.agAxis,
4328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004329 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004330 "error_if_validators": (
4331 TosaErrorValidator.evAxisLargerRank,
4332 TosaErrorValidator.evAxisSmallerZero,
4333 TosaErrorValidator.evShapeOfAxisNotOne,
4334 TosaErrorValidator.evWrongInputType,
4335 TosaErrorValidator.evWrongOutputType,
4336 TosaErrorValidator.evWrongRank,
4337 TosaErrorValidator.evWrongInputList,
4338 TosaErrorValidator.evWrongOutputList,
4339 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004340 "data_gen": {
4341 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004345 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004346 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 "build_fcn": (
4348 build_reduce,
4349 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004350 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004351 TosaArgGen.agAxis,
4352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 "error_if_validators": (
4355 TosaErrorValidator.evAxisLargerRank,
4356 TosaErrorValidator.evAxisSmallerZero,
4357 TosaErrorValidator.evShapeOfAxisNotOne,
4358 TosaErrorValidator.evWrongInputType,
4359 TosaErrorValidator.evWrongOutputType,
4360 TosaErrorValidator.evWrongRank,
4361 TosaErrorValidator.evWrongInputList,
4362 TosaErrorValidator.evWrongOutputList,
4363 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004364 "data_gen": {
4365 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004368 "reduce_product": {
4369 "op": Op.REDUCE_PRODUCT,
4370 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004371 "build_fcn": (
4372 build_reduce,
4373 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004374 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004375 TosaArgGen.agAxis,
4376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 "error_if_validators": (
4379 TosaErrorValidator.evAxisLargerRank,
4380 TosaErrorValidator.evAxisSmallerZero,
4381 TosaErrorValidator.evShapeOfAxisNotOne,
4382 TosaErrorValidator.evWrongInputType,
4383 TosaErrorValidator.evWrongOutputType,
4384 TosaErrorValidator.evWrongRank,
4385 TosaErrorValidator.evWrongInputList,
4386 TosaErrorValidator.evWrongOutputList,
4387 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004388 "data_gen": {
4389 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004392 "reduce_sum": {
4393 "op": Op.REDUCE_SUM,
4394 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004395 "build_fcn": (
4396 build_reduce,
4397 TosaTensorGen.tgBasic,
4398 TosaTensorValuesGen.tvgReduceSum,
4399 TosaArgGen.agAxis,
4400 ),
James Ward24dbc422022-10-19 12:20:31 +01004401 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004402 "error_if_validators": (
4403 TosaErrorValidator.evAxisLargerRank,
4404 TosaErrorValidator.evAxisSmallerZero,
4405 TosaErrorValidator.evShapeOfAxisNotOne,
4406 TosaErrorValidator.evWrongInputType,
4407 TosaErrorValidator.evWrongOutputType,
4408 TosaErrorValidator.evWrongRank,
4409 TosaErrorValidator.evWrongInputList,
4410 TosaErrorValidator.evWrongOutputList,
4411 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004412 "data_gen": {
4413 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004415 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004416 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004417 "concat": {
4418 "op": Op.CONCAT,
4419 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004420 "build_fcn": (
4421 build_concat,
4422 TosaTensorGen.tgConcat,
4423 TosaTensorValuesGen.tvgConcat,
4424 TosaArgGen.agAxis,
4425 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004426 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004427 "error_if_validators": (
4428 TosaErrorValidator.evAxisLargerRank,
4429 TosaErrorValidator.evAxisSmallerZero,
4430 TosaErrorValidator.evConcatInputRankMismatch,
4431 TosaErrorValidator.evConcatShapeSumMismatch,
4432 TosaErrorValidator.evConcatInputDimMismatch,
4433 TosaErrorValidator.evWrongInputType,
4434 TosaErrorValidator.evWrongOutputType,
4435 TosaErrorValidator.evWrongOutputList,
4436 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004437 "data_gen": {
4438 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4439 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004440 },
4441 "pad": {
4442 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004443 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004444 "build_fcn": (
4445 build_pad,
4446 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004447 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004448 TosaArgGen.agPad,
4449 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004450 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004451 "error_if_validators": (
4452 TosaErrorValidator.evWrongInputType,
4453 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004454 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004455 TosaErrorValidator.evWrongOutputType,
4456 TosaErrorValidator.evWrongInputList,
4457 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004458 TosaErrorValidator.evRankMismatch,
4459 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004460 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004461 "data_gen": {
4462 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4463 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004464 },
Won Jeona21b2e82023-08-10 10:33:01 +00004465 "dim": {
4466 "op": Op.DIM,
4467 "operands": (1, 0),
4468 "build_fcn": (
4469 build_dim,
4470 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004471 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004472 TosaArgGen.agAxis,
4473 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004474 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004475 "error_if_validators": (
4476 TosaErrorValidator.evAxisLargerRank,
4477 TosaErrorValidator.evAxisSmallerZero,
4478 TosaErrorValidator.evWrongInputType,
4479 TosaErrorValidator.evWrongInputList,
4480 TosaErrorValidator.evWrongOutputList,
4481 TosaErrorValidator.evWrongRank,
4482 ),
4483 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004484 "reshape": {
4485 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004486 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004487 "build_fcn": (
4488 build_reshape,
4489 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004490 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004491 TosaArgGen.agReshape,
4492 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004493 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004494 "error_if_validators": (
4495 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4496 TosaErrorValidator.evWrongInputType,
4497 TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongInputList,
4499 TosaErrorValidator.evWrongOutputList,
4500 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004501 "data_gen": {
4502 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4503 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004504 },
4505 "reverse": {
4506 "op": Op.REVERSE,
4507 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004508 "build_fcn": (
4509 build_reverse,
4510 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004511 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004512 TosaArgGen.agAxis,
4513 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004514 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004515 "error_if_validators": (
4516 TosaErrorValidator.evAxisSmallerZero,
4517 TosaErrorValidator.evAxisLargerRank,
4518 TosaErrorValidator.evWrongInputType,
4519 TosaErrorValidator.evWrongOutputType,
4520 TosaErrorValidator.evWrongInputList,
4521 TosaErrorValidator.evWrongOutputList,
4522 ),
evacha0198477222024-01-26 12:25:32 +00004523 "data_gen": {
4524 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4525 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004526 },
4527 "slice": {
4528 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004529 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004530 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004531 "build_fcn": (
4532 build_slice,
4533 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004534 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004535 TosaArgGen.agSlice,
4536 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004537 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004539 # TODO Turn off these error categories for now as the reference
4540 # model cannot allocate memory space for empty tensor. We probably
4541 # can report an accurate error message at the right place during
4542 # execution.
4543 # TosaErrorValidator.evStartSmallerZero,
4544 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004545 TosaErrorValidator.evStartSizeOutsideBounds,
4546 TosaErrorValidator.evSizeOutputShapeMismatch,
4547 TosaErrorValidator.evInputSizeStartLengthMismatch,
4548 TosaErrorValidator.evWrongRank,
4549 TosaErrorValidator.evWrongInputType,
4550 TosaErrorValidator.evWrongOutputType,
4551 TosaErrorValidator.evWrongInputList,
4552 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004553 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004554 ),
evacha017f7d4252024-01-24 12:08:09 +00004555 "data_gen": {
4556 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4557 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004558 },
4559 "tile": {
4560 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004561 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004562 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004563 "build_fcn": (
4564 build_tile,
4565 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004566 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004567 TosaArgGen.agTile,
4568 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004569 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004570 "error_if_validators": (
4571 TosaErrorValidator.evWrongInputType,
4572 TosaErrorValidator.evWrongOutputType,
4573 TosaErrorValidator.evWrongInputList,
4574 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004575 TosaErrorValidator.evRankMismatch,
4576 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004577 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004578 "data_gen": {
4579 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4580 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004581 },
4582 "transpose": {
4583 "op": Op.TRANSPOSE,
4584 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004585 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004586 "build_fcn": (
4587 build_transpose,
4588 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004589 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004590 TosaArgGen.agTranspose,
4591 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004592 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004593 "error_if_validators": (
4594 TosaErrorValidator.evIndexOutsideBounds,
4595 TosaErrorValidator.evIndexUsedTwice,
4596 TosaErrorValidator.evWrongInputType,
4597 TosaErrorValidator.evWrongOutputType,
4598 TosaErrorValidator.evWrongInputList,
4599 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004600 TosaErrorValidator.evWrongRank,
4601 TosaErrorValidator.evRankMismatch,
4602 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004603 ),
evacha0198477222024-01-26 12:25:32 +00004604 "data_gen": {
4605 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4606 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004607 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004608 # Data nodes
4609 "const": {
4610 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004611 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004612 "build_fcn": (
4613 build_const,
4614 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004615 TosaTensorValuesGen.tvgLazyGenDefault,
4616 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004617 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004618 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004619 "data_gen": {
4620 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4621 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004622 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004623 "identity": {
4624 "op": Op.IDENTITY,
4625 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004626 "build_fcn": (
4627 build_unary,
4628 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004629 TosaTensorValuesGen.tvgLazyGenDefault,
4630 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004631 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004632 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004633 "data_gen": {
4634 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4635 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004636 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004637 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004638 "gather": {
4639 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004640 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004641 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004642 "build_fcn": (
4643 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004644 TosaTensorGen.tgGather,
4645 TosaTensorValuesGen.tvgGather,
4646 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004647 ),
James Ward24dbc422022-10-19 12:20:31 +01004648 "types": (
4649 DType.INT8,
4650 DType.INT16,
4651 DType.INT32,
4652 DType.FP16,
4653 DType.BF16,
4654 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004655 DType.FP8E4M3,
4656 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004657 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004658 "error_if_validators": (
4659 TosaErrorValidator.evWrongInputType,
4660 TosaErrorValidator.evWrongOutputType,
4661 TosaErrorValidator.evWrongInputList,
4662 TosaErrorValidator.evWrongOutputList,
4663 TosaErrorValidator.evWrongRank,
4664 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004665 "data_gen": {
4666 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4667 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004668 },
4669 "scatter": {
4670 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004671 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004672 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004673 "build_fcn": (
4674 build_scatter,
4675 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004676 TosaTensorValuesGen.tvgScatter,
4677 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004678 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004679 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004680 "error_if_validators": (
4681 TosaErrorValidator.evWrongInputType,
4682 TosaErrorValidator.evWrongOutputType,
4683 TosaErrorValidator.evWrongInputList,
4684 TosaErrorValidator.evWrongOutputList,
4685 TosaErrorValidator.evWrongRank,
4686 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004687 "data_gen": {
4688 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4689 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004690 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004691 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004692 "resize": {
4693 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004694 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004696 "build_fcn": (
4697 build_resize,
4698 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004699 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004700 TosaArgGen.agResize,
4701 ),
James Ward24dbc422022-10-19 12:20:31 +01004702 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004703 "invalid_test_validators": (
4704 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004705 ),
4706 "error_if_validators": (
4707 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004708 TosaErrorValidator.evScaleSmallerEqualZero,
4709 TosaErrorValidator.evScaleNLargerMax,
4710 TosaErrorValidator.evScaleDLargerMax,
4711 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004712 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004713 TosaErrorValidator.evBorderSmallerMin,
4714 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004715 TosaErrorValidator.evWrongInputType,
4716 TosaErrorValidator.evWrongOutputType,
4717 TosaErrorValidator.evWrongRank,
4718 TosaErrorValidator.evWrongInputList,
4719 TosaErrorValidator.evWrongOutputList,
4720 TosaErrorValidator.evBatchMismatch,
4721 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004722 TosaErrorValidator.evResizeOutputShapeMismatch,
4723 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004724 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004725 "data_gen": {
4726 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4727 },
4728 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004729 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004730 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004731 "cast": {
4732 "op": Op.CAST,
4733 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004734 "build_fcn": (
4735 build_cast,
4736 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004737 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004738 TosaArgGen.agCast,
4739 ),
James Ward8b390432022-08-12 20:48:56 +01004740 "types": (
4741 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004742 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004743 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004744 DType.INT8,
4745 DType.INT16,
4746 DType.INT32,
4747 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004748 DType.FP8E4M3,
4749 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004750 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004751 "error_if_validators": (
4752 TosaErrorValidator.evWrongInputType,
4753 TosaErrorValidator.evWrongOutputType,
4754 TosaErrorValidator.evWrongInputList,
4755 TosaErrorValidator.evWrongOutputList,
4756 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004757 "data_gen": {
4758 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4759 },
4760 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004761 },
4762 "rescale": {
4763 "op": Op.RESCALE,
4764 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004765 "build_fcn": (
4766 build_rescale,
4767 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004768 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004769 TosaArgGen.agRescale,
4770 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004771 "types": [
4772 DType.UINT8,
4773 DType.INT8,
4774 DType.INT16,
4775 DType.INT32,
4776 DType.INT48,
4777 DType.UINT16,
4778 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004779 "error_if_validators": (
4780 TosaErrorValidator.evInputZeroPointNotZero,
4781 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004782 TosaErrorValidator.evU16InputZeroPointNotValid,
4783 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004784 TosaErrorValidator.evScaleTrue,
4785 TosaErrorValidator.evScaleNotTrue,
4786 TosaErrorValidator.evWrongInputType,
4787 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004788 TosaErrorValidator.evWrongInputList,
4789 TosaErrorValidator.evWrongOutputList,
4790 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004791 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004792 # Custom
4793 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004794 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004795 # Two varients of cond_if, one that generates one of two constant tensors (no
4796 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4797 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004798 "cond_if_const": {
4799 "op": Op.COND_IF,
4800 "operands": (0, 2),
4801 "build_fcn": (
4802 build_cond_if_const,
4803 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004804 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004805 TosaArgGen.agCondIf,
4806 ),
4807 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004808 "error_if_validators": (
4809 TosaErrorValidator.evOutputListThenGraphMismatch,
4810 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004811 TosaErrorValidator.evCondIfCondNotMatchingBool,
4812 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004813 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004814 },
4815 "cond_if_binary": {
4816 "op": Op.COND_IF,
4817 "operands": (2, 0),
4818 "build_fcn": (
4819 build_cond_if_binary,
4820 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004821 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004822 TosaArgGen.agCondIf,
4823 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004824 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004825 "error_if_validators": (
4826 TosaErrorValidator.evInputListThenGraphMismatch,
4827 TosaErrorValidator.evInputListElseGraphMismatch,
4828 TosaErrorValidator.evOutputListThenGraphMismatch,
4829 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004830 TosaErrorValidator.evCondIfCondNotMatchingBool,
4831 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004832 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004833 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004834 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004835 "while_loop": {
4836 "op": Op.WHILE_LOOP,
4837 "operands": (0, 1),
4838 "build_fcn": (
4839 build_while_loop,
4840 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004841 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004842 TosaArgGen.agWhileLoop,
4843 ),
4844 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004845 "error_if_validators": (
4846 TosaErrorValidator.evInputListOutputListMismatch,
4847 TosaErrorValidator.evInputListCondGraphMismatch,
4848 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4849 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4850 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004851 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004853 },
Luke Hutton57287132023-02-06 14:54:18 +00004854 "fft2d": {
4855 "op": Op.FFT2D,
4856 "operands": (2, 0),
4857 "rank": (3, 3),
4858 "build_fcn": (
4859 build_fft2d,
4860 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004861 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004862 TosaArgGen.agFFT2d,
4863 ),
4864 "types": [DType.FP32],
4865 "error_if_validators": (
4866 TosaErrorValidator.evWrongInputType,
4867 TosaErrorValidator.evWrongOutputType,
4868 TosaErrorValidator.evWrongInputList,
4869 TosaErrorValidator.evWrongOutputList,
4870 TosaErrorValidator.evWrongRank,
4871 TosaErrorValidator.evBatchMismatch,
4872 TosaErrorValidator.evKernelNotPowerOfTwo,
4873 TosaErrorValidator.evFFTInputShapeMismatch,
4874 TosaErrorValidator.evFFTOutputShapeMismatch,
4875 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004876 "data_gen": {
4877 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4878 },
Luke Hutton57287132023-02-06 14:54:18 +00004879 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004880 "rfft2d": {
4881 "op": Op.RFFT2D,
4882 "operands": (1, 0),
4883 "rank": (3, 3),
4884 "build_fcn": (
4885 build_rfft2d,
4886 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004887 TosaTensorValuesGen.tvgLazyGenDefault,
4888 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004889 ),
4890 "types": [DType.FP32],
4891 "error_if_validators": (
4892 TosaErrorValidator.evWrongInputType,
4893 TosaErrorValidator.evWrongOutputType,
4894 TosaErrorValidator.evWrongInputList,
4895 TosaErrorValidator.evWrongOutputList,
4896 TosaErrorValidator.evWrongRank,
4897 TosaErrorValidator.evBatchMismatch,
4898 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004899 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004900 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004901 "data_gen": {
4902 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4903 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004904 },
Won Jeon74342e52024-01-09 00:34:40 +00004905 # Shape
4906 "add_shape": {
4907 "op": Op.ADD_SHAPE,
4908 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004909 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004910 "build_fcn": (
4911 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004912 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004913 TosaTensorValuesGen.tvgAddSub,
4914 TosaArgGen.agNone,
4915 ),
4916 "types": [DType.SHAPE],
4917 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4918 },
4919 "sub_shape": {
4920 "op": Op.SUB_SHAPE,
4921 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004922 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004923 "build_fcn": (
4924 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004925 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004926 TosaTensorValuesGen.tvgAddSub,
4927 TosaArgGen.agNone,
4928 ),
4929 "types": [DType.SHAPE],
4930 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4931 },
4932 "mul_shape": {
4933 "op": Op.MUL_SHAPE,
4934 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004935 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004936 "build_fcn": (
4937 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004938 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004939 TosaTensorValuesGen.tvgMul,
4940 TosaArgGen.agNone,
4941 ),
4942 "types": [DType.SHAPE],
4943 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4944 },
4945 "div_shape": {
4946 "op": Op.DIV_SHAPE,
4947 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004948 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004949 "build_fcn": (
4950 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004951 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004952 TosaTensorValuesGen.tvgIntDiv,
4953 TosaArgGen.agNone,
4954 ),
4955 "types": [DType.SHAPE],
4956 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4957 },
4958 "concat_shape": {
4959 "op": Op.CONCAT_SHAPE,
4960 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004961 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004962 "build_fcn": (
4963 build_concat,
4964 TosaTensorGen.tgConcat,
4965 TosaTensorValuesGen.tvgConcat,
4966 TosaArgGen.agNone,
4967 ),
4968 "types": [DType.SHAPE],
4969 "error_if_validators": (),
4970 },
4971 "const_shape": {
4972 "op": Op.CONST_SHAPE,
4973 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004974 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004975 "build_fcn": (
4976 build_const,
4977 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004978 TosaTensorValuesGen.tvgLazyGenDefault,
4979 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004980 ),
4981 "types": [DType.SHAPE],
4982 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004983 }
4984
Kevin Cheng550ccc52021-03-03 11:21:43 -08004985
Eric Kunzee5e26762020-10-13 16:11:07 -07004986class OutputShaper:
4987 # Methods in this class compute the expected output shape and datatype
4988 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; the instance keeps no state of its own."""
4991
4992 # These methods return arguments that can be used for
4993 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of a broadcasting binary elementwise op.

    When no ERROR_IF case is requested, any dimension of ``a`` equal to 1
    takes the matching extent from ``b``; otherwise ``a``'s extent is kept.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor; must share ``a``'s dtype, and its rank
            unless ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # b.shape is only indexed when broadcasting, so mismatched ranks in the
    # ERROR_IF cases never index past the end of b.shape
    shape = [
        b.shape[idx] if (a_dim == 1 and error_name is None) else a_dim
        for idx, a_dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: skipping it would shift every later draw from rng
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005027
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor of a binary op whose inputs must match exactly.

    Both inputs must agree in rank, dtype and every dimension (checked with
    asserts). The output is a fresh list copy of ``a``'s shape, with
    ``a``'s dtype.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005039
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor of a unary elementwise op.

    The output always copies ``a``'s shape. Its dtype is ``a.dtype``
    unless ``error_name`` is ``ErrorIf.WrongOutputType``; in that case a
    different dtype is drawn at random with ``rng``.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        a: input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005058
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor of SELECT.

    The output shape follows ``cond``. When no ERROR_IF case is requested,
    each dimension where ``cond`` has extent 1 becomes the largest of the
    ``cond``/``a``/``b`` extents at that position.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        cond: condition tensor.
        a: first value tensor.
        b: second value tensor; dtypes of ``a`` and ``b`` must match, and
            all three ranks must match unless ``error_name`` is
            ``ErrorIf.RankMismatch``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    shape = [
        max(c_dim, a.shape[idx], b.shape[idx])
        if (c_dim == 1 and broadcasting)
        else c_dim
        for idx, c_dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally: skipping it would shift every later draw from rng
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005092
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor of a broadcasting comparison op.

    Any dimension of ``a`` equal to 1 takes ``b``'s extent when ``b`` has
    that dimension. The output dtype is BOOL, except for
    ``ErrorIf.WrongOutputType``, where a non-BOOL dtype is drawn at random.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        a: first input tensor.
        b: second input tensor; must share ``a``'s dtype, and its rank
            unless ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast, guarding against b having fewer dimensions than a
    shape = [
        b.shape[idx] if (a_dim == 1 and len(b.shape) > idx) else a_dim
        for idx, a_dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: skipping it would shift every later draw from rng
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.BOOL)

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005126
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a reduction along ``axis``.

    The reduced axis becomes extent 1, except in the axis ERROR_IF cases
    (``AxisSmallerZero``, ``AxisLargerRank``, ``ShapeOfAxisNotOne``), which
    leave ``a``'s shape untouched. ``ShapeOfAxisNotOne`` also replaces an
    extent of 1 with ``rng.integers(2, 10)``, so it is never 1.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        a: input tensor; ``a.shape`` is copied with ``.copy()``.
        axis: index of the reduced dimension.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
        shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005155
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the INT32 index output tensor of ARGMAX along ``axis``.

    ``axis`` is removed from ``a``'s shape unless ``error_name`` is
    ``AxisSmallerZero`` or ``AxisLargerRank``. ERROR_IF cases then perturb
    the result:

    - ``ArgmaxOutputRankMismatch``: drop the leading dimension or append a 1.
    - ``ArgmaxOutputShapeMismatch``: grow every dimension by a random amount.
    - ``WrongOutputType``: choose a random non-INT32 dtype.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        a: input tensor; ``a.shape`` is copied with ``.copy()``.
        axis: index of the reduced dimension.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(shape) > 1:
            del shape[0]
        else:
            shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(shape):
            shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.INT32)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {DType.INT32})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005191
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial extent
    is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    The output channel count is the filter's O dimension.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype.
        strides: ``[stride_h, stride_w]``.
        padding: ``[h_before, h_after, w_before, w_after]``.
        dilations: ``[dilation_h, dilation_w]``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # The three cases are mutually exclusive. The FP8 test must be an
        # `elif`; otherwise its `else` overwrites the FP16 exclusions.
        if ifm.dtype == DType.FP16:
            # FP16 inputs are treated as valid with both FP16 and FP32
            # outputs (same rule as conv3dOp/depthwiseConv2dOp), so neither
            # may be chosen as the "wrong" type
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005245
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial extent
    is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    The output channel count is the filter's O dimension.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype.
        strides: ``[stride_d, stride_h, stride_w]``.
        padding: before/after pairs for D, H, then W (six values).
        dilations: ``[dilation_d, dilation_h, dilation_w]``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # Spatial axis i (0=D, 1=H, 2=W) reads input/filter dim 1 + i,
    # pads 2i and 2i + 1, and dilation/stride i
    spatial = []
    for axis in range(3):
        extent = (
            ifm.shape[1 + axis]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[1 + axis] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 grows D/H/W alone and 4 grows all three; growth is a
        # whole number of strides so the non-integer error case is not hit
        for axis in range(3):
            if change in [axis + 1, 4]:
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # INT32 is a plausible output dtype when the input type itself is wrong
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5307
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial
    extent is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.
    The output channel count is the product of the filter's C and M dims.

    Args:
        ser: serializer; the result of its ``addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz ERROR_IF outputs.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype, used as the output dtype.
        strides: ``[stride_h, stride_w]``.
        padding: ``[h_before, h_after, w_before, w_after]``.
        dilations: ``[dilation_h, dilation_w]``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # Spatial axis i (0=H, 1=W) reads input dim 1 + i, filter dim i,
    # pads 2i and 2i + 1, and dilation/stride i
    spatial = []
    for axis in range(2):
        extent = (
            ifm.shape[1 + axis]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1
        spatial.append(extent)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change 1 grows H, 2 grows W, 3 grows both; growth is a whole number
        # of strides so the non-integer error case is not hit
        for axis in range(2):
            if change in [axis + 1, 3]:
                spatial[axis] = spatial[axis] + (rng.choice(choices) * strides[axis])

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    # INT32 is a plausible output dtype when the input type itself is wrong
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005358
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator and return it.

    ``ifm`` is NHWC: batch (dim 0) and channels (dim 3) are copied to the
    output, while dims 1/2 feed the pooling size formula together with
    ``kernel`` (KH, KW), ``stride`` (SY, SX) and ``pad``
    (top, bottom, left, right).

    ERROR_IF handling via ``error_name``:
      * PoolingOutputShapeMismatch - H and/or W grown by a whole multiple of
        the stride, so the shape is wrong but still integral.
      * WrongOutputType - dtype drawn from a list that excludes ``ifm.dtype``.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # Incorrect stride/padding: the test is invalid anyway, so use 1x1.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> perturb H only, 2 -> W only, 3 -> both.
        which = rng.choice(steps)
        # Grow in multiples of the stride to avoid the non-integer error case.
        if which != 2:
            out_h += rng.choice(steps) * stride[0]
        if which != 1:
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005398
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED and return it.

    ``input`` is (N, IC) and ``filter`` is (OC, IC), so the output is
    (N, OC). The dtype is always ``accum_dtype``, which the argument
    generator validates (and deliberately invalidates for ERROR_IF tests).
    ``rng`` and ``error_name`` are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005411
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL and return it.

    ``a`` is (N, H, C) and ``b`` is (N, C, W), giving an (N, H, W) output.

    Output dtype:
      * WrongOutputType - drawn from a per-input-dtype tuple of incorrect
        output types. Input dtypes without a dedicated tuple fall back to
        any usable dtype other than ``accum_dtype``.
      * WrongInputType - INT32, a potentially correct output dtype.
      * otherwise ``accum_dtype`` (validated in arg_gen).
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # Input dtypes not listed above: any usable dtype except the
            # expected accumulator dtype is an incorrect output type. This
            # keeps incorrect_types always bound (no UnboundLocalError).
            incorrect_types = list(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005477
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT and return it.

    The output shape starts as a copy of the first input's shape. The sizes
    of every other input along ``axis`` are added to it, unless the ERROR_IF
    case makes that impossible (rank mismatch or an out-of-range axis).

    ERROR_IF handling via ``error_name``:
      * ConcatShapeSumMismatch - an extra ``rng.integers(5, 10)`` is added
        along ``axis``.
      * WrongOutputType - dtype drawn from a set excluding the first input's
        dtype.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    # Tensors of different ranks, or an invalid axis, cannot be summed.
    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(dtype_pool - {first.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005513
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD and return it.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    ERROR_IF handling via ``error_name``:
      * PadOutputShapeMismatch - one random dimension grown by 1 or 2.
      * RankMismatch - shape replaced by gtu.get_rank_mismatch_shape().
      * WrongOutputType - dtype drawn from a list excluding ``a.dtype``.
    """
    output_shape = a.shape.copy()
    for dim_idx, dim in enumerate(a.shape):
        pads = padding[dim_idx]
        output_shape[dim_idx] = pads[0] + pads[1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005544
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the single-element output tensor for DIM and return it.

    The shape is always ``[1]``. The dtype is ``DType.SHAPE``, except for the
    WrongOutputType ERROR_IF case, where a random non-SHAPE dtype is used.
    ``a`` and ``axis`` are not read here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5565
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE and return it.

    The output shape is a copy of the requested ``shape``.

    ERROR_IF handling via ``error_name``:
      * TensorSizeInputOutputMismatch - every dimension has
        ``rng.integers(1, 10)`` added to it.
      * WrongOutputType - dtype drawn from a list excluding ``a.dtype``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim_idx, dim in enumerate(shape):
            output_shape[dim_idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(dtype_pool) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005590
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE and return it.

    The output shape is a copy of ``size``; ``start`` is not read here.
    The dtype is chosen before the shape, so rng draws happen in that order.

    ERROR_IF handling via ``error_name``:
      * WrongOutputType - dtype drawn from a list excluding ``input.dtype``.
      * SizeOutputShapeMismatch - every dimension nudged by a small delta.
      * InputSizeStartLengthMismatch - shape copied from ``input`` instead.
      * RankMismatch - shape replaced by gtu.get_rank_mismatch_shape().
    """
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim_idx, dim in enumerate(size):
            # Dims of 2 or less only grow, so they never drop to zero or below.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[dim_idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005624
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE and return it.

    Each output dimension is ``a.shape[i] * multiples[i]``; ``multiples``
    must have one entry per dimension of ``a``.

    ERROR_IF handling via ``error_name``:
      * RankMismatch - shape replaced by gtu.get_rank_mismatch_shape().
      * WrongOutputType - dtype drawn from a list excluding ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim_idx, (dim, mult) in enumerate(zip(a.shape, multiples)):
        output_shape[dim_idx] = dim * mult

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005653
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for TRANSPOSE and return it.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For an invalid
    permutation (IndexOutsideBounds / IndexUsedTwice) the input shape is
    kept unchanged.

    ERROR_IF handling via ``error_name``:
      * TensorSizeInputOutputMismatch - every dimension has
        ``rng.integers(1, 10)`` added to it.
      * RankMismatch - shape replaced by gtu.get_rank_mismatch_shape().
      * WrongOutputType - dtype drawn from a list excluding ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    bad_perm = error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice)
    if not bad_perm:
        for dst_axis, src_axis in enumerate(perms):
            output_shape[dst_axis] = a.shape[src_axis]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst_axis in range(len(output_shape)):
            output_shape[dst_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(dtype_pool) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005686
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER and return it.

    ``values`` is expected to be rank 3 (not checked for WrongRank) and
    ``indices`` rank 2, with matching dim 0. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    For the WrongOutputType ERROR_IF case the dtype is drawn from a list
    excluding ``values.dtype``; otherwise it is ``values.dtype``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    n = values.shape[0]
    w = indices.shape[1]
    c = values.shape[2]
    output_shape = [n, w, c]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005712
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER and return it.

    ``values_in`` is expected to be rank 3 (not checked for WrongRank),
    ``indices`` rank 2 and ``input`` rank 3, with the N/W/C dims agreeing as
    asserted below. The output takes ``values_in.shape`` (the same object is
    passed on, not a copy).

    For the WrongOutputType ERROR_IF case the dtype is drawn from a list
    excluding ``values_in.dtype``; otherwise it is ``values_in.dtype``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(dtype_pool) - {values_in.dtype})
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005743
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE and return it.

    The output has the same shape as ``input``. The dtype depends on the
    input dtype: INT16 input gives INT32 output, anything else gives INT8.
    The input must be INT8 or INT16 unless the ERROR_IF case is
    WrongInputType. For WrongOutputType a different dtype is drawn at random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = [dt for dt in dtype_pool if dt != output_dtype]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005763
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE and return it.

    ``scale`` is read as (y_n, y_d, x_n, x_d), ``offset`` and ``border`` as
    (y, x). The output height/width are
    ``((in_dim - 1) * n - offset + border) // d + 1`` using input dims 1
    and 2. Batch (dim 0) and channels (dim 3) come from ``input``, and the
    dtype is ``output_dtype``. ``mode`` and ``input_dtype`` are not read here.

    ERROR_IF handling via ``error_name``:
      * any non-None value clamps OH/OW to at least 1 and, except for
        MaxDimExceeded, to at most gtu.MAX_RESIZE_DIMENSION - 1.
      * ScaleSmallerEqualZero - scale terms clamped to >= 1 for the size
        calculation only.
      * ResizeOutputShapeMismatch - OH and/or OW moved by one denominator.
      * WrongRank / BatchMismatch / ChannelMismatch - alter the output dims.

    NOTE: rng is drawn in a fixed statement order; reordering the statements
    below would change the values drawn.
    """
    # Calculate OH, OW
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Avoid zero/negative scale terms (e.g. division by zero) when
        # computing the size; the invalid scale itself is left to the op.
        scale_y_n = max(scale_y_n, 1)
        scale_y_d = max(scale_y_d, 1)
        scale_x_n = max(scale_x_n, 1)
        scale_x_d = max(scale_x_d, 1)

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Keep the output tensor valid, since scale, offset or border may
        # have been changed for ERROR_IFs and produced an unusable size.
        oh = max(oh, 1)
        ow = max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
            ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 -> change OH, 2 -> change OW, 3 -> change both.
        change = rng.choice(choices)
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if change in [1, 3]:
            # Step down instead of up if stepping up would reach the max dim.
            if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += scale_y_d
        if change in [2, 3]:
            if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += scale_x_d

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): dim 0 is reused as the last entry, presumably because
        # an input of the wrong rank may have no dim 3 - confirm.
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[0],
        ]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            oh,
            ow,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005842
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted like the other shape helpers but
    are not used.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005846
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D and return it.

    ``output_shape`` is supplied by the caller. NOTE: for
    ConvOutputShapeMismatch, dims 1 and/or 2 of that list are increased in
    place, so the caller's list sees the change too.

    Output dtype:
      * WrongInputType - INT32, a potentially correct output dtype.
      * WrongOutputType - a usable dtype other than the expected one (for an
        FP16 input both FP16 and FP32 are excluded).
      * otherwise ``accum_dtype``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1 -> change dim 1 only, 2 -> dim 2 only, 3 -> both.
        target = rng.choice(amounts)
        if target != 2:
            output_shape[1] += rng.choice(amounts)
        if target != 1:
            output_shape[2] += rng.choice(amounts)

    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005872
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D and return them as a list.

    ``ifm1`` and ``ifm2`` must share a dtype. Their shapes must match
    (unless FFTInputShapeMismatch) and be rank 3 (unless WrongRank). Both
    outputs get a copy of ``ifm1``'s shape and its dtype, optionally
    perturbed for an ERROR_IF case:
      * WrongOutputType - a usable dtype other than FP32.
      * BatchMismatch - dim 0 grown by ``rng.integers(1, 10)``.
      * FFTOutputShapeMismatch - dim 1 or 2 grown by ``rng.integers(1, 10)``.

    NOTE(review): the two outputs are presumably the real and imaginary
    parts - confirm against the operator definition.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        output_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
5903
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D and return them as a list.

    ``value`` must be rank 3 (unless WrongRank). Both outputs keep every
    input dimension except the last, which becomes ``last // 2 + 1``, and
    use ``value.dtype``, optionally perturbed for an ERROR_IF case:
      * WrongOutputType - a usable dtype other than FP32.
      * BatchMismatch - dim 0 grown by ``rng.integers(1, 10)``.
      * FFTOutputShapeMismatch - dim 1 or 2 grown by ``rng.integers(1, 10)``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    output_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    output_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        output_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005928
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for ADD_SHAPE and return it.

    ``a`` and ``b`` must share a dtype and (unless RankMismatch) a rank.
    The output shape is a copy of ``a.shape`` and the dtype is
    ``DType.SHAPE``, optionally perturbed for an ERROR_IF case:
      * DimensionMismatch - one random dimension grown by 1.
      * WrongOutputType - dtype drawn from gtu.get_wrong_output_type().
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = list(a.shape)

    # The fuzz index is drawn even when it is not used. rng.integers(0, 0)
    # would fail, so this assumes a has rank >= 1 - TODO confirm callers.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    out_dtype = (
        rng.choice(gtu.get_wrong_output_type(a.dtype))
        if error_name == ErrorIf.WrongOutputType
        else DType.SHAPE
    )
    return ser.addOutput(shape, out_dtype)