blob: ee45f0eca0b597f8734250ac93684c3fef7ed612 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every auto-generated C++ meta data file that
# TosaTestGen.serialize() emits; the copyright year is taken at import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
Won Jeon2c34b462024-02-06 18:37:00 +000079 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010080 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Won Jeon2c34b462024-02-06 18:37:00 +0000155 if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Jerry Ge20ab3df2024-01-26 16:56:55 +0000198 elif dtype == DType.INT16:
199 return np.int16(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype == DType.UINT16:
201 return np.uint16(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000202 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100203 return np.int64(self.rng.integers(low=low, high=high, size=shape))
Won Jeon2c34b462024-02-06 18:37:00 +0000204 elif dtype in (
205 DType.FP16,
206 DType.BF16,
207 DType.FP32,
208 DType.FP8E4M3,
209 DType.FP8E5M2,
210 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100211 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
212
213 if dtype == DType.FP16:
214 return np.float16(f_tensor)
215 else:
216 f32_tensor = np.float32(f_tensor)
217 if dtype == DType.BF16:
218 # Floor the last 16 bits of each f32 value
219 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
Won Jeon2c34b462024-02-06 18:37:00 +0000220 elif dtype == DType.FP8E4M3:
221 return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
222 elif dtype == DType.FP8E5M2:
223 return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
Jeremy Johnson1271c442023-09-05 11:39:26 +0100224 else:
225 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700226 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100227 # All other integer types
228 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700229
Kevin Cheng989cb052021-04-28 16:29:44 -0700230 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 placeholders = []
232
Kevin Cheng989cb052021-04-28 16:29:44 -0700233 assert len(shape_list) == len(dtype_list)
234
Jeremy Johnson1271c442023-09-05 11:39:26 +0100235 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700236 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100237 if not self.args.lazy_data_gen:
238 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700239 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700240
241 return placeholders
242
Kevin Cheng989cb052021-04-28 16:29:44 -0700243 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 consts = []
245
Kevin Cheng989cb052021-04-28 16:29:44 -0700246 assert len(shape_list) == len(dtype_list)
247
Jeremy Johnson1271c442023-09-05 11:39:26 +0100248 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700249 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100250 if not self.args.lazy_data_gen:
251 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700252 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700253
254 return consts
255
256 def makeShape(self, rank):
257 if self.targetted_shape:
258 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800259 return np.int32(
260 self.rng.integers(
261 low=self.args.tensor_shape_range[0],
262 high=self.args.tensor_shape_range[1],
263 size=rank,
264 )
265 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700266
267 def setTargetShape(self, shape):
268 self.targetted_shape = shape
269
270 def randInt(self, low=0, high=256):
271 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
272
273 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100274 low, high = self.getDTypeRange(dtype)
275
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100276 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100277 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100278 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100279 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100280 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100281 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
282 return gtu.vect_f32_to_bf16(rand_f32)
Won Jeon2c34b462024-02-06 18:37:00 +0000283 elif dtype == DType.FP8E4M3:
284 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
285 return gtu.vect_f32_to_fp8e4m3(rand_f32)
286 elif dtype == DType.FP8E5M2:
287 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
288 return gtu.vect_f32_to_fp8e5m2(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 elif dtype == DType.BOOL:
290 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000291 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 # Special size
293 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
295 return np.int32(self.rng.integers(low, high, size=1))[0]
296
297 def shapeStr(self, shape):
298
299 sStr = []
300 # Convert to strings
301 for i in shape:
302 sStr.append(str(i))
303
Kevin Cheng550ccc52021-03-03 11:21:43 -0800304 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100306 def typeStr(self, dtype):
307 if isinstance(dtype, list) or isinstance(dtype, tuple):
308 assert len(dtype) >= 2
309 strs = [self.typeStr(t) for t in dtype]
310 # Limit types to the first 2 as the 3rd is the accumulator
311 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700312 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100313 if dtype in gtu.DTYPE_ATTRIBUTES:
314 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700315 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100316 raise Exception(
317 "Unknown dtype, cannot convert to string: {}".format(dtype)
318 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100320 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100321 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100322 if dtype in gtu.DTYPE_ATTRIBUTES:
323 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100325 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700326
Luke Hutton57287132023-02-06 14:54:18 +0000327 def constrictBatchSize(self, shape):
328 # Limit the batch size unless an explicit target shape set
329 if self.args.max_batch_size and not self.args.target_shapes:
330 shape[0] = min(shape[0], self.args.max_batch_size)
331 return shape
332
James Ward30124a82023-02-02 14:56:33 +0000333 def makeDimension(self):
334 return self.randInt(
335 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
336 )
337
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100338 def tensorComplianceMetaData(
339 self, op, inputType, argsDict, outputTensor, errorName
340 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000341 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
342 UNSUPPORTED_NON_FP32_INPUT_OPS = (
343 Op.MATMUL,
344 Op.CONV2D,
345 Op.FULLY_CONNECTED,
346 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000347 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000348 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000349 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100350 if (
351 errorName
352 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000353 or (
354 not gtu.dtypeIsSupportedByCompliance(inputType)
355 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
356 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100357 ):
358 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100359 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100360
Jeremy Johnson1271c442023-09-05 11:39:26 +0100361 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100362 compliance_tens = {
363 "mode": None,
364 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
365 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
366 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100367 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
368 mode = gtu.ComplianceMode.DOT_PRODUCT
369 compliance_tens["dot_product_info"] = {
370 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100371 "ks": int(argsDict["ksb"])
372 if "ksb" in argsDict
373 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100374 }
375 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
376 mode = gtu.ComplianceMode.FP_SPECIAL
377 elif "compliance" in op and "ulp" in op["compliance"]:
378 mode = gtu.ComplianceMode.ULP
379 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000380 elif "compliance" in op and "relative" in op["compliance"]:
381 mode = gtu.ComplianceMode.RELATIVE
382 compliance_tens["relative_info"] = {
383 "max": argsDict["max_abs_value"],
384 "scale": op["compliance"]["relative"],
385 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100386 elif op["op"] == Op.REDUCE_PRODUCT:
387 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000388 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000389 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000390 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000391 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
392 compliance_tens["abs_error_info"] = {
393 "lower_bound": op["compliance"]["abs_error_lower_bound"]
394 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100395 else:
396 mode = gtu.ComplianceMode.EXACT
397 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
398
399 return compliance_tens
400
401 # Build Op functions
402 # Create the output tensor (calling OutputShaper as needed)
403 # Do final tweaks to attributes (if necessary for errorIf)
404 # Add Op into graph
405 # Return resulting tensor information or BuildInfo
406
407 class BuildInfo:
408 """Enhanced build information containing result tensor and associated compliance dict."""
409
410 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000411 if isinstance(resultTensor, list):
412 assert complianceDict is None or isinstance(complianceDict, list)
413 self.resultTensorList = resultTensor
414 self.complianceDictList = complianceDict
415 else:
416 self.resultTensorList = [resultTensor]
417 if complianceDict is None:
418 self.complianceDictList = None
419 else:
420 self.complianceDictList = [complianceDict]
421
422 def getComplianceInfo(self):
423 if self.complianceDictList is None:
424 return None
425 else:
426 tens_dict = {}
427 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
428 if comp is not None:
429 tens_dict[tens.name] = comp
430
431 if tens_dict:
432 # Have some compliance data, so return the info
433 compliance = {
434 "version": "0.1",
435 "tensors": tens_dict,
436 }
437 else:
438 compliance = None
439 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700440
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000441 def build_unary(
442 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
443 ):
444 assert len(inputs) == 1
445 a = inputs[0]
446 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100447
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000448 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100449
450 # Ensure new output type has correct qinfo
451 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000452 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000453 qinfo = [
454 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000455 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000456 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100457
458 # Invalidate Input/Output list for error if checks.
459 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000460 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100461 pCount, cCount = op["operands"]
462 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000463 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
464 self, error_name, input_list, output_list
465 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100466
Les Bell729b0352021-11-24 10:28:21 +0000467 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100468 self.ser,
469 validator_fcns,
470 error_name,
471 op=op,
472 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000473 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000475 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100476 input_list=input_list,
477 output_list=output_list,
478 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000479 ):
480 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100481
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000482 attr = None
483 if op["op"] == Op.NEGATE:
484 attr = ts.TosaSerializerAttribute()
485 attr.NegateAttribute(qinfo[0], qinfo[1])
486
487 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000488
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000489 compliance = self.tensorComplianceMetaData(
490 op, a.dtype, args_dict, result_tensor, error_name
491 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000492 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700493
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000494 def build_binary_broadcast(
495 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
496 ):
497 assert len(inputs) == 2
498 a, b = inputs
499 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000500 self.ser, self.rng, a, b, error_name
501 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100502
503 # Invalidate Input/Output list for error if checks.
504 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000505 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100506 pCount, cCount = op["operands"]
507 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000508 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
509 self, error_name, input_list, output_list
510 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100511
Les Bell729b0352021-11-24 10:28:21 +0000512 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100513 self.ser,
514 validator_fcns,
515 error_name,
516 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000517 input1=a,
518 input2=b,
519 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000520 output_dtype=result_tensor.dtype,
521 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100522 input_list=input_list,
523 output_list=output_list,
524 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000525 ):
526 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100527
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000529
Jeremy Johnson9a758382023-11-07 16:27:35 +0000530 compliance = self.tensorComplianceMetaData(
531 op, a.dtype, args_dict, result_tensor, error_name
532 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000533
534 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100536 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700537 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000538 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700539 return result_tens
540
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000541 def build_arithmetic_right_shift(
Jeremy Johnson587cc842024-02-08 11:45:44 +0000542 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000543 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000544 assert len(inputs) == 2
545 a, b = inputs
546 round = args_dict["round"]
547 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000548 self.ser, self.rng, a, b, error_name
549 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100550
551 # Invalidate Input/Output list for error if checks.
552 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000553 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 pCount, cCount = op["operands"]
555 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000556 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
557 self, error_name, input_list, output_list
558 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100559
Les Bell729b0352021-11-24 10:28:21 +0000560 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100561 self.ser,
562 validator_fcns,
563 error_name,
564 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000565 input1=a,
566 input2=b,
567 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000568 output_dtype=result_tensor.dtype,
569 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100570 input_list=input_list,
571 output_list=output_list,
572 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000573 ):
574 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800575
576 attr = ts.TosaSerializerAttribute()
577 attr.ArithmeticRightShiftAttribute(round)
578
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000579 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000580
581 compliance = self.tensorComplianceMetaData(
582 op, a.dtype, args_dict, result_tensor, error_name
583 )
584
585 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800586
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100587 def build_mul(
588 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
589 ):
590 assert len(inputs) == 2
591 a, b = inputs
592 shift = args_dict["shift"]
593
594 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000595 self.ser, self.rng, a, b, error_name
596 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700597
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100598 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100599 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100600 result_tensor.setDtype(DType.INT32)
601
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100602 if error_name == ErrorIf.WrongOutputType:
603 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
604 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100605 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100606
607 # Invalidate Input/Output list for error if checks.
608 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100609 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100610 pCount, cCount = op["operands"]
611 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000612 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
613 self, error_name, input_list, output_list
614 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100615
Les Bell729b0352021-11-24 10:28:21 +0000616 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617 self.ser,
618 validator_fcns,
619 error_name,
620 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000621 input1=a,
622 input2=b,
623 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100624 output_dtype=result_tensor.dtype,
625 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100626 input_list=input_list,
627 output_list=output_list,
628 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000629 ):
630 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700631
Kevin Chengaee1fac2020-11-11 13:54:06 -0800632 attr = ts.TosaSerializerAttribute()
633 attr.MulAttribute(shift)
634
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000635 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100636
637 compliance = self.tensorComplianceMetaData(
638 op, a.dtype, args_dict, result_tensor, error_name
639 )
640
641 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700642
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TABLE operator for a single input tensor.

    The lookup table is read from args_dict["table"]. Returns None when
    TosaErrorValidator.evValidateErrorIfs reports the test should not be
    built, otherwise a TosaTestGen.BuildInfo of (result tensor, compliance).
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    lut = args_dict["table"]
    out = OutputShaper.tableOp(self.ser, self.rng, ifm, error_name)

    # For ERROR_IF tests the name lists may be corrupted on purpose.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(lut)
    self.ser.addOperator(op["op"], input_list, output_list, table_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SELECT operator.

    ``inputs`` must contain exactly three tensors ordered (cond, a, b).
    Returns None when TosaErrorValidator.evValidateErrorIfs reports the test
    should not be built, otherwise a TosaTestGen.BuildInfo.
    """
    assert len(inputs) == 3
    cond, in_a, in_b = inputs

    out = OutputShaper.selectOp(self.ser, self.rng, cond, in_a, in_b, error_name)

    # For ERROR_IF tests the name lists may be corrupted on purpose.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, in_a.name, in_b.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=in_a,
        input3=in_b,
        input_shape=in_a.shape,
        input_dtype=in_a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    # SELECT is serialized without an attribute.
    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_a.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700733
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a binary comparison operator for the tensor pair (a, b).

    Returns None when TosaErrorValidator.evValidateErrorIfs reports the test
    should not be built, otherwise a TosaTestGen.BuildInfo of
    (result tensor, compliance meta data).
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out = OutputShaper.binaryComparisonOp(self.ser, self.rng, lhs, rhs, error_name)

    # For ERROR_IF tests the name lists may be corrupted on purpose.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    # Comparison operators are serialized without an attribute.
    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize an ARGMAX operator for a single input tensor.

    validator_fcns and error_name default to None, as they do in the other
    build_* methods. Positional callers are unaffected.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: A list holding exactly one tensor.
        args_dict: Must provide "axis", which is passed to OutputShaper and
            serialized via AxisAttribute.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator (e.g. AVG_POOL2D / MAX_POOL2D).

    args_dict supplies "kernel", "stride", "pad" and, optionally,
    "acc_type" (MAX_POOL2D has none, so DType.UNKNOWN is used). A qinfo of
    None is serialized as zero points [0, 0]. Returns None when
    TosaErrorValidator.evValidateErrorIfs reports the test should not be
    built, otherwise a TosaTestGen.BuildInfo.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    kernel = args_dict["kernel"]
    stride = args_dict["stride"]
    pad = args_dict["pad"]

    out = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # A WrongInputType test needs zero points that match the changed dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out.dtype),
        ]

    # For ERROR_IF tests the name lists may be corrupted on purpose.
    n_placeholders, n_consts = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo
    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
James Ward8b390432022-08-12 20:48:56 +0100900
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV2D operator.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (4 values) and
            "dilation".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            ConvAttribute. None (the signature default) now falls back to
            [0, 0], matching build_pool2d, instead of raising TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): the second zero point is derived from the output dtype
    # here, not from the weight dtype. Confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700982
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (6 values) and
            "dilation".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            ConvAttribute. None (the signature default) now falls back to
            [0, 0], matching build_pool2d, instead of raising TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): the second zero point is derived from the output dtype
    # here, not from the weight dtype. Confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001064
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (4 values,
            serialized as the out_pad attribute) and "out_shape".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            TransposeConvAttribute. None (the signature default) now falls
            back to [0, 0], matching build_pool2d, instead of raising
            TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): the second zero point is derived from the output dtype
    # here, not from the weight dtype. Confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001139
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            ConvAttribute. None (the signature default) now falls back to
            [0, 0], matching build_pool2d, instead of raising TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    # NOTE(review): the second zero point is derived from the output dtype
    # here, not from the weight dtype. Confirm this is intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001220
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Must provide "acc_type".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            FullyConnectedAttribute. None (the signature default) now falls
            back to [0, 0], matching build_pool2d, instead of raising
            TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL operator for the tensor pair (a, b).

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: Exactly two tensors, ordered (a, b).
        args_dict: Must provide "acc_type".
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Two zero points, serialized as qinfo[0] and qinfo[1] of the
            MatMulAttribute. None (the signature default) now falls back to
            [0, 0], matching build_pool2d, instead of raising TypeError.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Same fallback as build_pool2d: a missing qinfo means zero points of 0.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001326
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REDUCE_* operator for a single input tensor.

    validator_fcns now defaults to None, as it does in the other build_*
    methods. Positional callers are unaffected.

    Args:
        op: Operator definition; op["op"] and op["operands"] are read.
        inputs: A list holding exactly one tensor.
        args_dict: Must provide "axis". For a non-error REDUCE_PRODUCT test
            this method also writes args_dict["n"] = a.shape[axis], which
            mutates the caller's dict, before computing compliance data.
        validator_fcns: ERROR_IF validators for evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        None if TosaErrorValidator.evValidateErrorIfs returns falsy,
        otherwise TosaTestGen.BuildInfo(result_tensor, compliance).
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001375
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CLAMP operator with randomly drawn min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller
    becomes min and the larger becomes max. For the MaxSmallerMin ERROR_IF
    test the values are redrawn until they differ and then deliberately
    swapped.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    def draw_pair():
        # Two random values of the input dtype, drawn first-then-second.
        return [self.getRandNumberDType(src.dtype), self.getRandNumberDType(src.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal bounds cannot violate max >= min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    # ClampAttribute argument order: (int min, int max, fp min, fp max).
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 values are passed as fp32 to the reference_model.
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        clamp_args = (0, 0, min_val, max_val)
    elif src.dtype in (DType.INT8, DType.INT16):
        clamp_args = (min_val, max_val, 0, 0)
    else:
        # Zero bounds avoid an internal error for incorrect input types.
        clamp_args = (0, 0, 0, 0)

    attr = ts.TosaSerializerAttribute()
    attr.ClampAttribute(self.ser.builder, *clamp_args)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001444
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator whose attribute is a random FP32 value.

    Legacy-style builder: it takes the input tensor directly, runs no
    ERROR_IF validation, and returns the result tensor itself rather than a
    TosaTestGen.BuildInfo.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    rand_fp32 = self.getRandNumberDType(DType.FP32)
    attr.LeakyReluAttribute(rand_fp32)

    self.ser.addOperator(op["op"], [a.name], [out.name], attr)
    return out
1453
# TODO: needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator from a single input tensor.

    Legacy-style builder: only one input is wired up, no ERROR_IF
    validation is performed, and the result tensor itself is returned
    rather than a TosaTestGen.BuildInfo.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1460
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a single-input activation operator that carries no attribute.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001501
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONCAT or CONCAT_SHAPE operator over all of ``inputs``.

    CONCAT_SHAPE always concatenates on axis 0 and carries no attribute.
    CONCAT reads its axis from args_dict and serializes an AxisAttribute.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list of input tensors to concatenate.
        args_dict: test arguments; "axis" is read for CONCAT.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    is_shape_op = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_op else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in inputs], [out_tensor.name]
    )

    first = inputs[0]
    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert is_shape_op
        attr = None
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001560
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator whose padding amounts come from a second input.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [input tensor, padding tensor].
        args_dict: test arguments; "pad", "pad_const_int" and "pad_const_fp"
            are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: forwarded to the ERROR_IF validator.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    src, pad_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # The attribute's padding list is left empty so that the padding comes
    # from inputs[1].
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, pad_tensor.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001618
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator that queries one axis of the input tensor.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata is
        produced. Returns None when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001664
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a RESHAPE operator whose target shape is given as a second input.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [input tensor, shape tensor].
        args_dict: test arguments; "new_shape" is read and used to compute
            the output tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, src, new_shape, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape_tensor.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001708
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REVERSE operator along the axis given in ``args_dict``.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata is
        produced. Returns None when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001748
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TRANSPOSE operator using the permutation in ``args_dict``.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; "perms" is read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    perms = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, src, perms, error_name
    )

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001797
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SLICE operator whose start and size are given as extra inputs.

    The constant values in args_dict compute the output tensor and drive
    ERROR_IF validation. The start/size tensors become operator inputs.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [input tensor, start tensor, size tensor].
        args_dict: test arguments; "start" and "size" are read.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, src, start_vals, size_vals, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001845
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TILE operator whose multiples are given as a second input.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [input tensor, multiples tensor].
        args_dict: test arguments; "multiples" is read and used to compute
            the output tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    multiples_vals = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, self.rng, src, multiples_vals, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001890
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [values tensor, indices tensor].
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001933
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SCATTER operator from values_in, indices and input tensors.

    Args:
        op: operator description dict; "op" and "operands" are read.
        inputs: [values_in tensor, indices tensor, input tensor].
        args_dict: test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ErrorIf under test, or None for a valid test.
        qinfo: unused.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    # "updates" names the third operand without shadowing the builtin input().
    values_in, indices, updates = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, updates, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, updates.name],
        [out_tensor.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001975
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds four tensors: the data input followed by the scale,
    offset and border tensors. All four names are wired into the operator.
    The Python values of ``mode``, ``scale``, ``offset``, ``border`` and
    ``output_dtype`` come from ``args_dict``. They are used to compute the
    output shape and to drive ERROR_IF validation.

    ``validator_fcns`` defaults to ``None`` like the other builders in this
    class; callers passing it positionally are unaffected.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata), or ``None`` if the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 4
    input_tensor, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # scale/offset/border are supplied as input tensors, so the attribute
    # only carries the mode and empty lists for the other fields.
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002054
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    Each output shape/type is derived from its input with
    ``OutputShaper.unaryOp``. Only the first result tensor is returned (not a
    ``BuildInfo``). ``validator_fcns`` is accepted but unused.

    NOTE(review): unlike the other builders, ``op`` is handed directly to
    ``addOperator`` rather than ``op["op"]``. Confirm that callers pass an
    ``Op`` value here and not an op-list dictionary.
    """
    # Outputs are created in input order (val first, then val2)
    first_out, second_out = (
        OutputShaper.unaryOp(self.ser, self.rng, tensor, error_name)
        for tensor in (val, val2)
    )
    self.ser.addOperator(
        op, [val.name, val2.name], [first_out.name, second_out.name]
    )
    return first_out
2062
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test.

    The single tensor in ``inputs`` is registered directly as a graph output;
    no operator is added. Returns a ``TosaTestGen.BuildInfo`` for that tensor
    together with its compliance metadata.
    """
    assert len(inputs) == 1
    (const_tensor,) = inputs

    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002075
2076 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator test.

    Converts the single tensor in ``inputs`` to ``args_dict["out_type"]``.
    The output tensor comes from ``OutputShaper.typeConversionOp``.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata), or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    (source,) = inputs
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, source, target_dtype, error_name
    )

    # ERROR_IF tests may tamper with the operand name lists.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        output_shape=result_tensor.shape,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, source.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002120
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``args_dict`` supplies ``output_dtype``, ``scale`` (True for scale32,
    False for scale16), ``double_round`` and ``per_channel``.

    Zero points are chosen at random from the input/output dtypes. They are
    forced non-zero for the matching zero-point ERROR_IF cases.

    A random per-channel (or single) scale is converted to multiplier/shift
    pairs with ``TosaQuantGen.computeMultiplierAndShift``. For scale32 tests
    without an ERROR_IF, the saved input placeholder values are clamped so
    that ``value - input_zp`` satisfies the apply_scale_32 range requirement,
    and the placeholder file is rewritten if anything changed.

    The ``qinfo`` argument is ignored; the generated ``(input_zp, output_zp)``
    pair is what gets validated.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata), or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One multiplier/shift pair per channel (last dimension) when per_channel
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints: shift_arr is int32 and shifts
        # can be large, so evaluating the shift in the int32 dtype may
        # overflow (e.g. under NumPy 2 type promotion rules).
        half_range = 1 << (int(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the stored dtype. Compare the
        # actual extremes (ndarray.all() would only yield a bool here).
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=(input_zp, output_zp),
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002303
def _get_condition_tensor(self, op, cond, error_name):
    """Return a constant condition tensor holding ``cond`` for COND_IF tests.

    Normally this is a rank-0 BOOL constant. For the ERROR_IF case
    CondIfCondNotMatchingBool the dtype is replaced by a wrong type from
    ``gtu.get_wrong_output_type``. For CondIfCondShapeNotSizeOne the shape is
    randomly ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2320
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output an INT32 constant.

    Of the two supplied tensors only the shape of the first is used; it
    becomes the output shape. The condition value comes from
    ``args_dict["condition"]``. For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF cases the corresponding block outputs a constant with a
    perturbed shape instead.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata), or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    shape_template, _ = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = shape_template.shape
    dtype = DType.INT32

    corrupt_then = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    corrupt_else = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if corrupt_then or corrupt_else:
        # Perturb every dimension; dims <= 3 only grow so they stay positive
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in shape_template.shape
        ]
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result has the regular (uncorrupted) output shape
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Each block simply outputs a constant built inside that block
    for block_name, block_arr, corrupt in (
        (then_block, then_arr, corrupt_then),
        (else_block, else_arr, corrupt_else),
    ):
        self.ser.addBasicBlock(block_name)
        if corrupt:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not passed:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002405
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD (then) and SUB (else). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else). Any
    other dtype is rejected with an assertion.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF cases,
    the selected block gets an input or output tensor with a perturbed shape.

    Returns a ``TosaTestGen.BuildInfo`` (result tensor plus compliance
    metadata for the branch selected by ``args_dict["condition"]``), or
    ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension so the shape mismatches. Dimensions <= 3
        # are only grown, so the incorrect shape never contains zero or
        # negative sizes (same scheme as build_cond_if_const).
        incorrect_shape = [
            dim
            + (
                self.rng.choice([-3, -2, 2, 3])
                if dim > 3
                else self.rng.choice([1, 2, 4])
            )
            for dim in a.shape
        ]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2509
2510 def build_while_loop(
2511 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2512 ):
2513 assert len(inputs) == 1
2514 a = inputs[0]
2515 iter_val = args_dict["iterations"]
2516
Kevin Cheng550ccc52021-03-03 11:21:43 -08002517 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002518
Kevin Cheng550ccc52021-03-03 11:21:43 -08002519 cond_block = "COND_BLOCK"
2520 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002521
2522 attr = ts.TosaSerializerAttribute()
2523 attr.WhileLoopAttribute(cond_block, body_block)
2524
2525 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002526 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002527 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002528 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002529
2530 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002531 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2532 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002533 if error_name == ErrorIf.InputListOutputListMismatch:
2534 incorrect_acc = deepcopy(acc)
2535 for i in range(len(incorrect_acc.shape)):
2536 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2537 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2538 else:
2539 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002540
2541 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002542 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002543 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002544 [iter.name, a.name, acc.name],
2545 [iter_out.name, a_out.name, acc_out.name],
2546 attr,
2547 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002548 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002549
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002550 if error_name in [
2551 ErrorIf.InputListCondGraphMismatch,
2552 ErrorIf.InputListBodyGraphInputMismatch,
2553 ErrorIf.InputListBodyGraphOutputMismatch,
2554 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002555 incorrect_iter = deepcopy(iter)
2556 for i in range(len(incorrect_iter.shape)):
2557 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2558 if len(incorrect_iter.shape) == 0:
2559 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2560
2561 incorrect_acc = deepcopy(acc)
2562 for i in range(len(incorrect_acc.shape)):
2563 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2564
Eric Kunzee5e26762020-10-13 16:11:07 -07002565 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002566 self.ser.addBasicBlock(cond_block)
2567
Matthew Haddon630c17c2021-10-14 15:05:41 +01002568 if error_name == ErrorIf.InputListCondGraphMismatch:
2569 self.ser.addInputTensor(incorrect_iter)
2570 self.ser.addInputTensor(a)
2571 self.ser.addInputTensor(incorrect_acc)
2572 else:
2573 self.ser.addInputTensor(iter)
2574 self.ser.addInputTensor(a)
2575 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002576 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002577
2578 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002579 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002580 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002581 cond_type = DType.BOOL
2582 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2583 choice = self.rng.choice([1, 2])
2584 if choice == 1:
2585 cond_shape = [3]
2586 else:
2587 cond_shape = [1, 2]
2588 else:
2589 cond_shape = []
2590 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002591
Kevin Cheng550ccc52021-03-03 11:21:43 -08002592 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002593
2594 # BODY block (input: a, acc, iter, output: a, acc, iter)
2595 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002596 self.ser.addBasicBlock(body_block)
2597
Matthew Haddon630c17c2021-10-14 15:05:41 +01002598 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2599 self.ser.addInputTensor(incorrect_iter)
2600 self.ser.addInputTensor(a)
2601 self.ser.addInputTensor(incorrect_acc)
2602 else:
2603 self.ser.addInputTensor(iter)
2604 self.ser.addInputTensor(a)
2605 self.ser.addInputTensor(acc)
2606
Kevin Cheng550ccc52021-03-03 11:21:43 -08002607 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002608
2609 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002610 iter_body_out = self.ser.addIntermediate(
2611 incorrect_iter.shape, incorrect_iter.dtype
2612 )
2613 acc_body_out = self.ser.addIntermediate(
2614 incorrect_acc.shape, incorrect_acc.dtype
2615 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002616 else:
2617 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2618 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2619
Eric Kunzee5e26762020-10-13 16:11:07 -07002620 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2621 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2622 self.ser.addOutputTensor(iter_body_out)
2623 self.ser.addOutputTensor(a)
2624 self.ser.addOutputTensor(acc_body_out)
2625
Les Bell729b0352021-11-24 10:28:21 +00002626 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002627 self.ser,
2628 validator_fcns,
2629 error_name,
2630 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002631 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002632 ):
2633 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002634
Jeremy Johnson587cc842024-02-08 11:45:44 +00002635 compliance = self.tensorComplianceMetaData(
2636 op, a.dtype, args_dict, acc_out, error_name
2637 )
2638
2639 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from a pair of input tensors.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: exactly two input tensors.
        args_dict: generated arguments; must contain "inverse".
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: the ErrorIf case being generated, or None for a positive test.
        qinfo: not used by this builder; accepted so all build functions share
            the same call signature.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance entry
        per result, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    input_a, input_b = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(
        self.ser, self.rng, input_a, input_b, error_name
    )

    placeholder_count, const_count = op["operands"]
    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the helper may hand back altered name lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [input_a.name, input_b.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=input_a,
        input2=input_b,
        input_shape=input_a.shape,
        input_dtype=input_a.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, input_a.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002704
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: exactly one input tensor.
        args_dict: generated arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: the ErrorIf case being generated, or None for a positive test.
        qinfo: not used by this builder; accepted so all build functions share
            the same call signature.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance entry
        per result, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]

    results = OutputShaper.rfft2dOp(self.ser, self.rng, source, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [res.shape for res in results]
    output_dtypes = [res.dtype for res in results]

    # For ERROR_IF tests the helper may hand back altered name lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [res.name for res in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, source.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002761
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test.

    The result tensor is created by OutputShaper.addShapeOp. The operator is
    added without an attribute.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: exactly two input tensors.
        args_dict: generated arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions to run (may be None).
        error_name: the ErrorIf case being generated, or None for a positive test.
        qinfo: not used by this builder; accepted so all build functions share
            the same call signature.

    Returns:
        TosaTestGen.BuildInfo for the single result tensor, or None when the
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholder_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
2807
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters to generate tests with.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes; a falsy value means "no shape filter",
            represented as [None].
        rankFilter: requested ranks, or None. Ranks outside the operator's
            (min, max) range are dropped.
        dtypeFilter: requested dtypes, or None. A list-valued operator dtype
            (such as a TYPE_CONV entry) matches when its first element is
            requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator used for "negative" tests. Any non-None
            "rank"/"dtype"/"shape" in its "param_reqs" replaces that filter.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys. Returns
        None for a "negative" request without a validator, or for an
        unrecognized testType.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Default behaviour is bounded by the default range or by the operator,
        # whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Clip the default window to the ranks the operator allows, so an
            # op with e.g. rank (2, 6) is never given rank 1. If the two do not
            # overlap at all, fall back to the operator's lowest ranks.
            cleanRankFilter = [
                rank for rank in default_test_rank_range if rmin <= rank <= rmax
            ] or opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Set parameters as required by the ERROR_IF validator
    if error_arguments["rank"] is not None:
        rankFilter = error_arguments["rank"]
    else:
        rankFilter = cleanRankFilter

    if error_arguments["dtype"] is not None:
        dtypeFilter = error_arguments["dtype"]
    else:
        dtypeFilter = cleanDtypeFilter

    if error_arguments["shape"] is not None:
        shapeFilter = error_arguments["shape"]
    else:
        # Reduce number of shapes to keep test numbers small
        shapeFilter = shapeFilter[:2]

    return {
        "shapeFilter": shapeFilter,
        "rankFilter": rankFilter,
        "dtypeFilter": dtypeFilter,
    }
2888
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for one operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to. None
            (the default) means no shape filter, i.e. ``[None]``.
        rankFilter: optional ranks to restrict generation to.
        dtypeFilter: optional dtypes to restrict generation to.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. For "positive" tests, any test flagged by one of the op's
        "invalid_test_validators" is removed.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    # A fresh list per call rather than a shared mutable default argument
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = f"{opName}_{shapeStr}_{typeStr}"
                        else:
                            testStr = (
                                f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"
                            )
                        if argStr:
                            testStr = f"{testStr}_{argStr}"

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. Every validator is evaluated for every test.
        invalid_test_validators = op["invalid_test_validators"]
        clean_testList = []
        for test in testList:
            remove_test = False
            for validator_fcn in invalid_test_validators:
                if validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                ):
                    remove_test = True
            if not remove_test:
                clean_testList.append(test)
        testList = clean_testList

    return testList
2995
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build and serialize one test case.

    The arguments are the fields of one test tuple as produced by
    genOpTestList. The method:
    - creates a serializer;
    - obtains optional quantization info from the op's "qgen";
    - generates the input tensors with the op's tensor-values generator;
    - calls the op's build function.

    A truthy build result is serialized together with the extra desc.json meta
    data ("data_gen" and "compliance" entries, when present). Otherwise an
    "Invalid ERROR_IF test created" message is printed.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        testStr: name of the test.
        dtype_or_dtypeList: one dtype per operand, or a single dtype that is
            repeated per operand (per shape for CONCAT / CONCAT_SHAPE).
        error_name: the ErrorIf case, or None for a positive test.
        shapeList: the operand shapes.
        argsDict: argument dictionary; must be a dict containing "dg_type".

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
        AssertionError: on operand count mismatches or an old-style argsDict.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # For CONCAT / CONCAT_SHAPE the dtype count follows shapeList and the
    # operand count checks below are skipped
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    # Build the random tensor operands and the test
    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, add the compliance meta data (if any) and serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003086
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    For every kernel size, a shallow copy of each template is added under a
    name such as "conv2d_3x3", with "filter" set to the kernel and "template"
    set to False.
    - With args.level8k, kernels built from TOSA_8K_LEVEL_MAX_KERNEL are used.
    - Otherwise a fixed set of small kernels is used.

    Afterwards, every entry whose "template" field is truthy is deleted. The
    method returns immediately if "conv2d_TEMPLATE" is already gone (the class
    may be initialized more than once).
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(prefix, kernel):
        # Shallow copy of "<prefix>_TEMPLATE" with the kernel size filled in
        variant = op_list[prefix + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[prefix + "_" + "x".join(str(dim) for dim in kernel)] = variant

    # Kernel-major order keeps the same dictionary insertion order
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(prefix, kernel)

    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect first, delete second: never delete keys while iterating a dict
    stale = [name for name, entry in op_list.items() if entry.get("template")]
    for name in stale:
        del op_list[name]
3142
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must have:
    - "operands" unpackable into two values;
    - "build_fcn" unpackable into four values;
    - "types";
    - "op".

    A missing "rank" is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op with a missing or malformed required
            field.
    """
    for name, entry in self.TOSA_OP_LIST.items():

        # Required fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for field, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            try:
                _ = entry[field]
            except KeyError:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003184
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op enum value (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
# 'build_fcn': 4-tuple of (build function, tensor shape generator,
#     tensor values generator, argument generator)
# 'types': array of datatypes to be tested
# Other optional keys used by entries: 'qgen', 'error_if_validators',
#     'invalid_test_validators', 'data_gen'; templated entries carry
#     'template' and are expanded by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integers (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # INT8/INT16 and floating-types (no INT32)

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range given to ops that do not specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003244
3245 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003247 "argmax": {
3248 "op": Op.ARGMAX,
3249 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003250 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 "build_fcn": (
3252 build_argmax,
3253 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003254 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 TosaArgGen.agAxis,
3256 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003257 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003258 "error_if_validators": (
3259 TosaErrorValidator.evAxisSmallerZero,
3260 TosaErrorValidator.evAxisLargerRank,
3261 TosaErrorValidator.evArgmaxOutputRankMismatch,
3262 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3263 TosaErrorValidator.evWrongRank,
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003269 "data_gen": {
3270 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3271 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003272 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "avg_pool2d": {
3274 "op": Op.AVG_POOL2D,
3275 "operands": (1, 0),
3276 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003277 "build_fcn": (
3278 build_pool2d,
3279 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003280 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 TosaArgGen.agPooling,
3282 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003284 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003285 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003286 "error_if_validators": (
3287 TosaErrorValidator.evKernelSmallerOne,
3288 TosaErrorValidator.evStrideSmallerOne,
3289 TosaErrorValidator.evPadSmallerZero,
3290 TosaErrorValidator.evWrongRank,
3291 TosaErrorValidator.evWrongInputType,
3292 TosaErrorValidator.evWrongOutputType,
3293 TosaErrorValidator.evWrongInputList,
3294 TosaErrorValidator.evWrongOutputList,
3295 TosaErrorValidator.evInputZeroPointNotZero,
3296 TosaErrorValidator.evOutputZeroPointNotZero,
3297 TosaErrorValidator.evPadLargerEqualKernel,
3298 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003299 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003300 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003301 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003302 "data_gen": {
3303 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003305 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003306 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003307 "conv2d_TEMPLATE": {
3308 "op": Op.CONV2D,
3309 "operands": (1, 2),
3310 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003311 "build_fcn": (
3312 build_conv2d,
3313 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003314 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 TosaArgGen.agConv,
3316 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003317 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003318 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003319 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3320 "error_if_validators": (
3321 TosaErrorValidator.evWrongInputType,
3322 TosaErrorValidator.evWrongOutputType,
3323 TosaErrorValidator.evWrongInputList,
3324 TosaErrorValidator.evWrongOutputList,
3325 TosaErrorValidator.evInputZeroPointNotZero,
3326 TosaErrorValidator.evWeightZeroPointNotZero,
3327 TosaErrorValidator.evPadSmallerZero,
3328 TosaErrorValidator.evStrideSmallerOne,
3329 TosaErrorValidator.evDilationSmallerOne,
3330 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003331 TosaErrorValidator.evConvOutputShapeMismatch,
3332 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003333 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003334 "data_gen": {
3335 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3336 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003337 "template": True,
3338 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003339 # Templated operator. Filled in by createDynamicOpLists
3340 "conv3d_TEMPLATE": {
3341 "op": Op.CONV3D,
3342 "operands": (1, 2),
3343 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003344 "build_fcn": (
3345 build_conv3d,
3346 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003347 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 TosaArgGen.agConv,
3349 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003350 "qgen": TosaQuantGen.qgConv,
3351 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003352 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3353 "error_if_validators": (
3354 TosaErrorValidator.evWrongInputType,
3355 TosaErrorValidator.evWrongOutputType,
3356 TosaErrorValidator.evWrongInputList,
3357 TosaErrorValidator.evWrongOutputList,
3358 TosaErrorValidator.evInputZeroPointNotZero,
3359 TosaErrorValidator.evWeightZeroPointNotZero,
3360 TosaErrorValidator.evPadSmallerZero,
3361 TosaErrorValidator.evStrideSmallerOne,
3362 TosaErrorValidator.evDilationSmallerOne,
3363 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003364 TosaErrorValidator.evConvOutputShapeMismatch,
3365 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003366 ),
evacha0147ab1762024-01-29 13:23:23 +00003367 "data_gen": {
3368 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3369 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003370 "template": True,
3371 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003372 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003373 "depthwise_conv2d_TEMPLATE": {
3374 "op": Op.DEPTHWISE_CONV2D,
3375 "operands": (1, 2),
3376 "filter": [1, 1],
3377 "rank": (4, 4),
3378 "build_fcn": (
3379 build_depthwise_conv2d,
3380 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003381 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003382 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003383 ),
3384 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003385 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003386 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3387 "error_if_validators": (
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 TosaErrorValidator.evInputZeroPointNotZero,
3393 TosaErrorValidator.evWeightZeroPointNotZero,
3394 TosaErrorValidator.evPadSmallerZero,
3395 TosaErrorValidator.evStrideSmallerOne,
3396 TosaErrorValidator.evDilationSmallerOne,
3397 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003398 TosaErrorValidator.evConvOutputShapeMismatch,
3399 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003400 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003401 "data_gen": {
3402 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3403 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003404 "template": True,
3405 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "fully_connected": {
3407 "op": Op.FULLY_CONNECTED,
3408 "operands": (1, 2),
3409 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003410 "build_fcn": (
3411 build_fully_connected,
3412 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003413 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003414 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003415 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003417 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003418 "error_if_validators": (
3419 TosaErrorValidator.evInputZeroPointNotZero,
3420 TosaErrorValidator.evWeightZeroPointNotZero,
3421 TosaErrorValidator.evWrongRank,
3422 TosaErrorValidator.evWrongInputType,
3423 TosaErrorValidator.evWrongOutputType,
3424 TosaErrorValidator.evWrongInputList,
3425 TosaErrorValidator.evWrongOutputList,
3426 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003427 "data_gen": {
3428 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3429 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003430 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003431 "matmul": {
3432 "op": Op.MATMUL,
3433 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003434 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 "build_fcn": (
3436 build_matmul,
3437 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003438 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003439 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003442 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003443 "error_if_validators": (
3444 TosaErrorValidator.evInputZeroPointNotZero,
3445 TosaErrorValidator.evWrongRank,
3446 TosaErrorValidator.evWrongInputType,
3447 TosaErrorValidator.evWrongOutputType,
3448 TosaErrorValidator.evWrongInputList,
3449 TosaErrorValidator.evWrongOutputList,
3450 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003451 "data_gen": {
3452 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003453 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003454 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 "max_pool2d": {
3456 "op": Op.MAX_POOL2D,
3457 "operands": (1, 0),
3458 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003460 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003461 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003462 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463 TosaArgGen.agPooling,
3464 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003465 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003466 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 "error_if_validators": (
3468 TosaErrorValidator.evKernelSmallerOne,
3469 TosaErrorValidator.evStrideSmallerOne,
3470 TosaErrorValidator.evPadSmallerZero,
3471 TosaErrorValidator.evWrongRank,
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongInputList,
3475 TosaErrorValidator.evWrongOutputList,
3476 TosaErrorValidator.evPadLargerEqualKernel,
3477 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003478 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003479 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003480 "data_gen": {
3481 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3482 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003484 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 "transpose_conv2d_TEMPLATE": {
3486 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003487 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003488 "rank": (4, 4),
3489 "build_fcn": (
3490 build_transpose_conv2d,
3491 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003492 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003493 TosaArgGen.agTransposeConv2D,
3494 ),
3495 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003496 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003497 "invalid_test_validators": (
3498 TosaInvalidValidator.ivHeightWidthInvalid,
3499 TosaInvalidValidator.ivNonPositiveOutputShape,
3500 ),
3501 "error_if_validators": (
3502 TosaErrorValidator.evWrongInputType,
3503 TosaErrorValidator.evWrongOutputType,
3504 TosaErrorValidator.evWrongInputList,
3505 TosaErrorValidator.evWrongOutputList,
3506 TosaErrorValidator.evInputZeroPointNotZero,
3507 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003508 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003509 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003510 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003511 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003512 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003513 "data_gen": {
3514 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3515 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003516 "template": True,
3517 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003518 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 "clamp": {
3520 "op": Op.CLAMP,
3521 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 "build_fcn": (
3523 build_clamp,
3524 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003525 TosaTensorValuesGen.tvgLazyGenDefault,
3526 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003528 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003529 "error_if_validators": (
3530 TosaErrorValidator.evMaxSmallerMin,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003536 "data_gen": {
3537 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3538 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003539 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003540 "sigmoid": {
3541 "op": Op.SIGMOID,
3542 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003544 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003545 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003546 TosaTensorValuesGen.tvgLazyGenDefault,
3547 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003548 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003549 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003550 "error_if_validators": (
3551 TosaErrorValidator.evWrongInputType,
3552 TosaErrorValidator.evWrongOutputType,
3553 TosaErrorValidator.evWrongInputList,
3554 TosaErrorValidator.evWrongOutputList,
3555 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003556 "data_gen": {
3557 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3558 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003559 },
3560 "tanh": {
3561 "op": Op.TANH,
3562 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003563 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003564 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003565 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003566 TosaTensorValuesGen.tvgLazyGenDefault,
3567 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003569 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003570 "error_if_validators": (
3571 TosaErrorValidator.evWrongInputType,
3572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003576 "data_gen": {
3577 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3578 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003579 "compliance": {
3580 "abs_error_lower_bound": 0.5,
3581 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 },
Won Jeon78155c62023-06-10 00:20:04 +00003583 "erf": {
3584 "op": Op.ERF,
3585 "operands": (1, 0),
3586 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003587 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003588 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003589 TosaTensorValuesGen.tvgLazyGenDefault,
3590 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003591 ),
3592 "types": TYPE_FP,
3593 "error_if_validators": (
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003599 "data_gen": {
3600 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3601 },
3602 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003603 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003604 # Elementwise Binary Operators
3605 "add": {
3606 "op": Op.ADD,
3607 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003608 "build_fcn": (
3609 build_binary_broadcast,
3610 TosaTensorGen.tgBroadcastFuzz,
3611 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003612 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003615 "error_if_validators": (
3616 TosaErrorValidator.evRankMismatch,
3617 TosaErrorValidator.evWrongInputType,
3618 TosaErrorValidator.evWrongOutputType,
3619 TosaErrorValidator.evWrongInputList,
3620 TosaErrorValidator.evWrongOutputList,
3621 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003622 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003623 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003624 "data_gen": {
3625 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3626 },
3627 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003629 "arithmetic_right_shift": {
3630 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3631 "operands": (2, 0),
3632 "build_fcn": (
3633 build_arithmetic_right_shift,
3634 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 TosaArgGen.agArithmeticRightShift,
3637 ),
3638 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evRankMismatch,
3641 TosaErrorValidator.evWrongInputType,
3642 TosaErrorValidator.evWrongOutputType,
3643 TosaErrorValidator.evWrongInputList,
3644 TosaErrorValidator.evWrongOutputList,
3645 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003646 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003648 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 "bitwise_and": {
3650 "op": Op.BITWISE_AND,
3651 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 "build_fcn": (
3653 build_binary_broadcast,
3654 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003655 TosaTensorValuesGen.tvgLazyGenDefault,
3656 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003658 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evRankMismatch,
3661 TosaErrorValidator.evWrongInputType,
3662 TosaErrorValidator.evWrongOutputType,
3663 TosaErrorValidator.evWrongInputList,
3664 TosaErrorValidator.evWrongOutputList,
3665 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003666 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "bitwise_or": {
3670 "op": Op.BITWISE_OR,
3671 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_binary_broadcast,
3674 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003675 TosaTensorValuesGen.tvgLazyGenDefault,
3676 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evRankMismatch,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003686 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 "bitwise_xor": {
3690 "op": Op.BITWISE_XOR,
3691 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003692 "build_fcn": (
3693 build_binary_broadcast,
3694 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003695 TosaTensorValuesGen.tvgLazyGenDefault,
3696 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003699 "error_if_validators": (
3700 TosaErrorValidator.evRankMismatch,
3701 TosaErrorValidator.evWrongInputType,
3702 TosaErrorValidator.evWrongOutputType,
3703 TosaErrorValidator.evWrongInputList,
3704 TosaErrorValidator.evWrongOutputList,
3705 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003706 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003708 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003709 "intdiv": {
3710 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003711 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 "build_fcn": (
3713 build_binary_broadcast,
3714 TosaTensorGen.tgBroadcastFuzz,
3715 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003716 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003717 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003718 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 "error_if_validators": (
3720 TosaErrorValidator.evRankMismatch,
3721 TosaErrorValidator.evWrongInputType,
3722 TosaErrorValidator.evWrongOutputType,
3723 TosaErrorValidator.evWrongInputList,
3724 TosaErrorValidator.evWrongOutputList,
3725 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003726 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003727 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003728 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003729 "logical_and": {
3730 "op": Op.LOGICAL_AND,
3731 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 "build_fcn": (
3733 build_binary_broadcast,
3734 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003735 TosaTensorValuesGen.tvgLazyGenDefault,
3736 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 "error_if_validators": (
3740 TosaErrorValidator.evRankMismatch,
3741 TosaErrorValidator.evWrongInputType,
3742 TosaErrorValidator.evWrongOutputType,
3743 TosaErrorValidator.evWrongInputList,
3744 TosaErrorValidator.evWrongOutputList,
3745 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003746 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 "logical_left_shift": {
3750 "op": Op.LOGICAL_LEFT_SHIFT,
3751 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_binary_broadcast,
3754 TosaTensorGen.tgBroadcastFuzz,
3755 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "error_if_validators": (
3760 TosaErrorValidator.evRankMismatch,
3761 TosaErrorValidator.evWrongInputType,
3762 TosaErrorValidator.evWrongOutputType,
3763 TosaErrorValidator.evWrongInputList,
3764 TosaErrorValidator.evWrongOutputList,
3765 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003766 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003767 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003768 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 "logical_right_shift": {
3770 "op": Op.LOGICAL_RIGHT_SHIFT,
3771 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 "build_fcn": (
3773 build_binary_broadcast,
3774 TosaTensorGen.tgBroadcastFuzz,
3775 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003776 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003778 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 "error_if_validators": (
3780 TosaErrorValidator.evRankMismatch,
3781 TosaErrorValidator.evWrongInputType,
3782 TosaErrorValidator.evWrongOutputType,
3783 TosaErrorValidator.evWrongInputList,
3784 TosaErrorValidator.evWrongOutputList,
3785 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003786 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003788 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 "logical_or": {
3790 "op": Op.LOGICAL_OR,
3791 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 "build_fcn": (
3793 build_binary_broadcast,
3794 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003795 TosaTensorValuesGen.tvgLazyGenDefault,
3796 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 "error_if_validators": (
3800 TosaErrorValidator.evRankMismatch,
3801 TosaErrorValidator.evWrongInputType,
3802 TosaErrorValidator.evWrongOutputType,
3803 TosaErrorValidator.evWrongInputList,
3804 TosaErrorValidator.evWrongOutputList,
3805 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003806 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003807 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003808 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003809 "logical_xor": {
3810 "op": Op.LOGICAL_XOR,
3811 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 "build_fcn": (
3813 build_binary_broadcast,
3814 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003815 TosaTensorValuesGen.tvgLazyGenDefault,
3816 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evRankMismatch,
3821 TosaErrorValidator.evWrongInputType,
3822 TosaErrorValidator.evWrongOutputType,
3823 TosaErrorValidator.evWrongInputList,
3824 TosaErrorValidator.evWrongOutputList,
3825 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003826 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 "maximum": {
3830 "op": Op.MAXIMUM,
3831 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 "build_fcn": (
3833 build_binary_broadcast,
3834 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003835 TosaTensorValuesGen.tvgLazyGenDefault,
3836 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003837 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003838 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 "error_if_validators": (
3840 TosaErrorValidator.evRankMismatch,
3841 TosaErrorValidator.evWrongInputType,
3842 TosaErrorValidator.evWrongOutputType,
3843 TosaErrorValidator.evWrongInputList,
3844 TosaErrorValidator.evWrongOutputList,
3845 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003846 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 "data_gen": {
3849 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3850 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003851 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 "minimum": {
3853 "op": Op.MINIMUM,
3854 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003855 "build_fcn": (
3856 build_binary_broadcast,
3857 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003858 TosaTensorValuesGen.tvgLazyGenDefault,
3859 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003860 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 "error_if_validators": (
3863 TosaErrorValidator.evRankMismatch,
3864 TosaErrorValidator.evWrongInputType,
3865 TosaErrorValidator.evWrongOutputType,
3866 TosaErrorValidator.evWrongInputList,
3867 TosaErrorValidator.evWrongOutputList,
3868 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003869 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003870 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003871 "data_gen": {
3872 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3873 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003874 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003875 "mul": {
3876 "op": Op.MUL,
3877 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003878 "build_fcn": (
3879 build_mul,
3880 TosaTensorGen.tgBroadcastFuzz,
3881 TosaTensorValuesGen.tvgMul,
3882 TosaArgGen.agMul,
3883 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003885 "error_if_validators": (
3886 TosaErrorValidator.evWrongInputType,
3887 TosaErrorValidator.evWrongOutputType,
3888 TosaErrorValidator.evWrongInputList,
3889 TosaErrorValidator.evWrongOutputList,
3890 TosaErrorValidator.evRankMismatch,
3891 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003892 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003894 "data_gen": {
3895 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3896 },
3897 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 "pow": {
3900 "op": Op.POW,
3901 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003902 "build_fcn": (
3903 build_binary_broadcast,
3904 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003905 TosaTensorValuesGen.tvgPow,
3906 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003907 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003908 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 "error_if_validators": (
3910 TosaErrorValidator.evRankMismatch,
3911 TosaErrorValidator.evWrongInputType,
3912 TosaErrorValidator.evWrongOutputType,
3913 TosaErrorValidator.evWrongInputList,
3914 TosaErrorValidator.evWrongOutputList,
3915 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003916 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003918 "data_gen": {
3919 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003922 "sub": {
3923 "op": Op.SUB,
3924 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003925 "build_fcn": (
3926 build_binary_broadcast,
3927 TosaTensorGen.tgBroadcastFuzz,
3928 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003929 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003930 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003932 "error_if_validators": (
3933 TosaErrorValidator.evRankMismatch,
3934 TosaErrorValidator.evWrongInputType,
3935 TosaErrorValidator.evWrongOutputType,
3936 TosaErrorValidator.evWrongInputList,
3937 TosaErrorValidator.evWrongOutputList,
3938 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003939 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003941 "data_gen": {
3942 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3943 },
3944 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 "table": {
3947 "op": Op.TABLE,
3948 # Use the automatic generation functions to create the input array
3949 # but create the table tensor in the build function, as it may be
3950 # a different type from the input
3951 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 "build_fcn": (
3953 build_table,
3954 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003955 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003956 TosaArgGen.agTable,
3957 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003958 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003959 "error_if_validators": (
3960 TosaErrorValidator.evWrongInputType,
3961 TosaErrorValidator.evWrongOutputType,
3962 TosaErrorValidator.evWrongInputList,
3963 TosaErrorValidator.evWrongOutputList,
3964 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 # Elementwise Unary operators
3967 "abs": {
3968 "op": Op.ABS,
3969 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_unary,
3972 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003973 TosaTensorValuesGen.tvgLazyGenDefault,
3974 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 "error_if_validators": (
3978 TosaErrorValidator.evWrongInputType,
3979 TosaErrorValidator.evWrongOutputType,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003983 "data_gen": {
3984 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3985 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003987 "bitwise_not": {
3988 "op": Op.BITWISE_NOT,
3989 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003990 "build_fcn": (
3991 build_unary,
3992 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003993 TosaTensorValuesGen.tvgLazyGenDefault,
3994 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003995 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003996 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003997 "error_if_validators": (
3998 TosaErrorValidator.evWrongInputType,
3999 TosaErrorValidator.evWrongOutputType,
4000 TosaErrorValidator.evWrongInputList,
4001 TosaErrorValidator.evWrongOutputList,
4002 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 "ceil": {
4005 "op": Op.CEIL,
4006 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004007 "build_fcn": (
4008 build_unary,
4009 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004010 TosaTensorValuesGen.tvgLazyGenDefault,
4011 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004012 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004013 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 "error_if_validators": (
4015 TosaErrorValidator.evWrongInputType,
4016 TosaErrorValidator.evWrongOutputType,
4017 TosaErrorValidator.evWrongInputList,
4018 TosaErrorValidator.evWrongOutputList,
4019 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004020 "data_gen": {
4021 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4022 },
4023 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004024 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004025 "clz": {
4026 "op": Op.CLZ,
4027 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004028 "build_fcn": (
4029 build_unary,
4030 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004031 TosaTensorValuesGen.tvgLazyGenDefault,
4032 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004033 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004034 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004035 "error_if_validators": (
4036 TosaErrorValidator.evWrongInputType,
4037 TosaErrorValidator.evWrongOutputType,
4038 TosaErrorValidator.evWrongInputList,
4039 TosaErrorValidator.evWrongOutputList,
4040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004042 "exp": {
4043 "op": Op.EXP,
4044 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004045 "build_fcn": (
4046 build_unary,
4047 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004048 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004049 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004050 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004052 "error_if_validators": (
4053 TosaErrorValidator.evWrongInputType,
4054 TosaErrorValidator.evWrongOutputType,
4055 TosaErrorValidator.evWrongInputList,
4056 TosaErrorValidator.evWrongOutputList,
4057 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004058 "data_gen": {
4059 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004062 "floor": {
4063 "op": Op.FLOOR,
4064 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004065 "build_fcn": (
4066 build_unary,
4067 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004068 TosaTensorValuesGen.tvgLazyGenDefault,
4069 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004070 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004071 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004072 "error_if_validators": (
4073 TosaErrorValidator.evWrongInputType,
4074 TosaErrorValidator.evWrongOutputType,
4075 TosaErrorValidator.evWrongInputList,
4076 TosaErrorValidator.evWrongOutputList,
4077 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004078 "data_gen": {
4079 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4080 },
4081 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004083 "log": {
4084 "op": Op.LOG,
4085 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004086 "build_fcn": (
4087 build_unary,
4088 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004089 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004090 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004093 "error_if_validators": (
4094 TosaErrorValidator.evWrongInputType,
4095 TosaErrorValidator.evWrongOutputType,
4096 TosaErrorValidator.evWrongInputList,
4097 TosaErrorValidator.evWrongOutputList,
4098 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004099 "data_gen": {
4100 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4101 },
4102 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004104 "logical_not": {
4105 "op": Op.LOGICAL_NOT,
4106 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 "build_fcn": (
4108 build_unary,
4109 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004110 TosaTensorValuesGen.tvgLazyGenDefault,
4111 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004112 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004113 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004114 "error_if_validators": (
4115 TosaErrorValidator.evWrongInputType,
4116 TosaErrorValidator.evWrongOutputType,
4117 TosaErrorValidator.evWrongInputList,
4118 TosaErrorValidator.evWrongOutputList,
4119 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004121 "negate": {
4122 "op": Op.NEGATE,
4123 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004124 "build_fcn": (
4125 build_unary,
4126 TosaTensorGen.tgBasic,
4127 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004128 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004129 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004130 "qgen": TosaQuantGen.qgUnary,
4131 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004132 "error_if_validators": (
4133 TosaErrorValidator.evInputZeroPointNotZero,
4134 TosaErrorValidator.evOutputZeroPointNotZero,
4135 TosaErrorValidator.evWrongInputType,
4136 TosaErrorValidator.evWrongOutputType,
4137 TosaErrorValidator.evWrongInputList,
4138 TosaErrorValidator.evWrongOutputList,
4139 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004140 "data_gen": {
4141 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 "reciprocal": {
4145 "op": Op.RECIPROCAL,
4146 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004147 "build_fcn": (
4148 build_unary,
4149 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004150 TosaTensorValuesGen.tvgLazyGenDefault,
4151 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004152 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004153 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004154 "error_if_validators": (
4155 TosaErrorValidator.evWrongInputType,
4156 TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongInputList,
4158 TosaErrorValidator.evWrongOutputList,
4159 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004160 "data_gen": {
4161 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4162 },
4163 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "rsqrt": {
4166 "op": Op.RSQRT,
4167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004168 "build_fcn": (
4169 build_unary,
4170 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004171 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004172 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004175 "error_if_validators": (
4176 TosaErrorValidator.evWrongInputType,
4177 TosaErrorValidator.evWrongOutputType,
4178 TosaErrorValidator.evWrongInputList,
4179 TosaErrorValidator.evWrongOutputList,
4180 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004181 "data_gen": {
4182 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4183 },
4184 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004186 # Elementwise Ternary operators
4187 "select": {
4188 "op": Op.SELECT,
4189 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004190 "build_fcn": (
4191 build_select,
4192 TosaTensorGen.tgBroadcastFuzz,
4193 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004194 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004196 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004197 "error_if_validators": (
4198 TosaErrorValidator.evRankMismatch,
4199 TosaErrorValidator.evWrongInputType,
4200 TosaErrorValidator.evWrongOutputType,
4201 TosaErrorValidator.evWrongInputList,
4202 TosaErrorValidator.evWrongOutputList,
4203 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004204 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004205 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004206 "data_gen": {
4207 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4208 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 # Comparison operators
4211 "equal": {
4212 "op": Op.EQUAL,
4213 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004214 "build_fcn": (
4215 build_comparison,
4216 TosaTensorGen.tgBroadcastFuzz,
4217 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004218 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004219 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004221 "error_if_validators": (
4222 TosaErrorValidator.evRankMismatch,
4223 TosaErrorValidator.evWrongInputType,
4224 TosaErrorValidator.evWrongOutputType,
4225 TosaErrorValidator.evWrongInputList,
4226 TosaErrorValidator.evWrongOutputList,
4227 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004228 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004229 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004230 "data_gen": {
4231 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4232 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004234 "greater_equal": {
4235 "op": Op.GREATER_EQUAL,
4236 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_comparison,
4239 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004240 TosaTensorValuesGen.tvgLazyGenDefault,
4241 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004243 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004244 "error_if_validators": (
4245 TosaErrorValidator.evRankMismatch,
4246 TosaErrorValidator.evWrongInputType,
4247 TosaErrorValidator.evWrongOutputType,
4248 TosaErrorValidator.evWrongInputList,
4249 TosaErrorValidator.evWrongOutputList,
4250 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004251 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004252 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004253 "data_gen": {
4254 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4255 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004257 "greater": {
4258 "op": Op.GREATER,
4259 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004260 "build_fcn": (
4261 build_comparison,
4262 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004263 TosaTensorValuesGen.tvgLazyGenDefault,
4264 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004266 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004267 "error_if_validators": (
4268 TosaErrorValidator.evRankMismatch,
4269 TosaErrorValidator.evWrongInputType,
4270 TosaErrorValidator.evWrongOutputType,
4271 TosaErrorValidator.evWrongInputList,
4272 TosaErrorValidator.evWrongOutputList,
4273 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004274 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004275 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004276 "data_gen": {
4277 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004279 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004280 # Reduction operators
4281 "reduce_all": {
4282 "op": Op.REDUCE_ALL,
4283 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004284 "build_fcn": (
4285 build_reduce,
4286 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004287 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004288 TosaArgGen.agAxis,
4289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004290 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004291 "error_if_validators": (
4292 TosaErrorValidator.evAxisLargerRank,
4293 TosaErrorValidator.evAxisSmallerZero,
4294 TosaErrorValidator.evShapeOfAxisNotOne,
4295 TosaErrorValidator.evWrongInputType,
4296 TosaErrorValidator.evWrongOutputType,
4297 TosaErrorValidator.evWrongRank,
4298 TosaErrorValidator.evWrongInputList,
4299 TosaErrorValidator.evWrongOutputList,
4300 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004302 "reduce_any": {
4303 "op": Op.REDUCE_ANY,
4304 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004305 "build_fcn": (
4306 build_reduce,
4307 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004308 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004309 TosaArgGen.agAxis,
4310 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004311 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 "error_if_validators": (
4313 TosaErrorValidator.evAxisLargerRank,
4314 TosaErrorValidator.evAxisSmallerZero,
4315 TosaErrorValidator.evShapeOfAxisNotOne,
4316 TosaErrorValidator.evWrongInputType,
4317 TosaErrorValidator.evWrongOutputType,
4318 TosaErrorValidator.evWrongRank,
4319 TosaErrorValidator.evWrongInputList,
4320 TosaErrorValidator.evWrongOutputList,
4321 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004323 "reduce_max": {
4324 "op": Op.REDUCE_MAX,
4325 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004326 "build_fcn": (
4327 build_reduce,
4328 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004329 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004330 TosaArgGen.agAxis,
4331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004332 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004333 "error_if_validators": (
4334 TosaErrorValidator.evAxisLargerRank,
4335 TosaErrorValidator.evAxisSmallerZero,
4336 TosaErrorValidator.evShapeOfAxisNotOne,
4337 TosaErrorValidator.evWrongInputType,
4338 TosaErrorValidator.evWrongOutputType,
4339 TosaErrorValidator.evWrongRank,
4340 TosaErrorValidator.evWrongInputList,
4341 TosaErrorValidator.evWrongOutputList,
4342 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004343 "data_gen": {
4344 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4345 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004347 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004348 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004349 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004350 "build_fcn": (
4351 build_reduce,
4352 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004353 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004354 TosaArgGen.agAxis,
4355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004356 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004357 "error_if_validators": (
4358 TosaErrorValidator.evAxisLargerRank,
4359 TosaErrorValidator.evAxisSmallerZero,
4360 TosaErrorValidator.evShapeOfAxisNotOne,
4361 TosaErrorValidator.evWrongInputType,
4362 TosaErrorValidator.evWrongOutputType,
4363 TosaErrorValidator.evWrongRank,
4364 TosaErrorValidator.evWrongInputList,
4365 TosaErrorValidator.evWrongOutputList,
4366 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004367 "data_gen": {
4368 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004371 "reduce_product": {
4372 "op": Op.REDUCE_PRODUCT,
4373 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004374 "build_fcn": (
4375 build_reduce,
4376 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004377 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004378 TosaArgGen.agAxis,
4379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004380 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004381 "error_if_validators": (
4382 TosaErrorValidator.evAxisLargerRank,
4383 TosaErrorValidator.evAxisSmallerZero,
4384 TosaErrorValidator.evShapeOfAxisNotOne,
4385 TosaErrorValidator.evWrongInputType,
4386 TosaErrorValidator.evWrongOutputType,
4387 TosaErrorValidator.evWrongRank,
4388 TosaErrorValidator.evWrongInputList,
4389 TosaErrorValidator.evWrongOutputList,
4390 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004391 "data_gen": {
4392 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004394 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004395 "reduce_sum": {
4396 "op": Op.REDUCE_SUM,
4397 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004398 "build_fcn": (
4399 build_reduce,
4400 TosaTensorGen.tgBasic,
4401 TosaTensorValuesGen.tvgReduceSum,
4402 TosaArgGen.agAxis,
4403 ),
James Ward24dbc422022-10-19 12:20:31 +01004404 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004405 "error_if_validators": (
4406 TosaErrorValidator.evAxisLargerRank,
4407 TosaErrorValidator.evAxisSmallerZero,
4408 TosaErrorValidator.evShapeOfAxisNotOne,
4409 TosaErrorValidator.evWrongInputType,
4410 TosaErrorValidator.evWrongOutputType,
4411 TosaErrorValidator.evWrongRank,
4412 TosaErrorValidator.evWrongInputList,
4413 TosaErrorValidator.evWrongOutputList,
4414 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004415 "data_gen": {
4416 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004418 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004419 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004420 "concat": {
4421 "op": Op.CONCAT,
4422 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004423 "build_fcn": (
4424 build_concat,
4425 TosaTensorGen.tgConcat,
4426 TosaTensorValuesGen.tvgConcat,
4427 TosaArgGen.agAxis,
4428 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004429 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004430 "error_if_validators": (
4431 TosaErrorValidator.evAxisLargerRank,
4432 TosaErrorValidator.evAxisSmallerZero,
4433 TosaErrorValidator.evConcatInputRankMismatch,
4434 TosaErrorValidator.evConcatShapeSumMismatch,
4435 TosaErrorValidator.evConcatInputDimMismatch,
4436 TosaErrorValidator.evWrongInputType,
4437 TosaErrorValidator.evWrongOutputType,
4438 TosaErrorValidator.evWrongOutputList,
4439 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004440 "data_gen": {
4441 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4442 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004443 },
4444 "pad": {
4445 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004446 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004447 "build_fcn": (
4448 build_pad,
4449 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004450 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004451 TosaArgGen.agPad,
4452 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004453 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 "error_if_validators": (
4455 TosaErrorValidator.evWrongInputType,
4456 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004457 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004458 TosaErrorValidator.evWrongOutputType,
4459 TosaErrorValidator.evWrongInputList,
4460 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004461 TosaErrorValidator.evRankMismatch,
4462 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004463 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004464 "data_gen": {
4465 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4466 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004467 },
Won Jeona21b2e82023-08-10 10:33:01 +00004468 "dim": {
4469 "op": Op.DIM,
4470 "operands": (1, 0),
4471 "build_fcn": (
4472 build_dim,
4473 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004474 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004475 TosaArgGen.agAxis,
4476 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004477 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004478 "error_if_validators": (
4479 TosaErrorValidator.evAxisLargerRank,
4480 TosaErrorValidator.evAxisSmallerZero,
4481 TosaErrorValidator.evWrongInputType,
4482 TosaErrorValidator.evWrongInputList,
4483 TosaErrorValidator.evWrongOutputList,
4484 TosaErrorValidator.evWrongRank,
4485 ),
4486 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004487 "reshape": {
4488 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004489 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004490 "build_fcn": (
4491 build_reshape,
4492 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004493 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004494 TosaArgGen.agReshape,
4495 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004496 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004497 "error_if_validators": (
4498 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4499 TosaErrorValidator.evWrongInputType,
4500 TosaErrorValidator.evWrongOutputType,
4501 TosaErrorValidator.evWrongInputList,
4502 TosaErrorValidator.evWrongOutputList,
4503 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004504 "data_gen": {
4505 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4506 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004507 },
4508 "reverse": {
4509 "op": Op.REVERSE,
4510 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004511 "build_fcn": (
4512 build_reverse,
4513 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004514 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004515 TosaArgGen.agAxis,
4516 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004517 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 "error_if_validators": (
4519 TosaErrorValidator.evAxisSmallerZero,
4520 TosaErrorValidator.evAxisLargerRank,
4521 TosaErrorValidator.evWrongInputType,
4522 TosaErrorValidator.evWrongOutputType,
4523 TosaErrorValidator.evWrongInputList,
4524 TosaErrorValidator.evWrongOutputList,
4525 ),
evacha0198477222024-01-26 12:25:32 +00004526 "data_gen": {
4527 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4528 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004529 },
4530 "slice": {
4531 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004532 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004533 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004534 "build_fcn": (
4535 build_slice,
4536 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004537 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004538 TosaArgGen.agSlice,
4539 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004540 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004542 # TODO Turn off these error categories for now as the reference
4543 # model cannot allocate memory space for empty tensor. We probably
4544 # can report an accurate error message at the right place during
4545 # execution.
4546 # TosaErrorValidator.evStartSmallerZero,
4547 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004548 TosaErrorValidator.evStartSizeOutsideBounds,
4549 TosaErrorValidator.evSizeOutputShapeMismatch,
4550 TosaErrorValidator.evInputSizeStartLengthMismatch,
4551 TosaErrorValidator.evWrongRank,
4552 TosaErrorValidator.evWrongInputType,
4553 TosaErrorValidator.evWrongOutputType,
4554 TosaErrorValidator.evWrongInputList,
4555 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004556 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004557 ),
evacha017f7d4252024-01-24 12:08:09 +00004558 "data_gen": {
4559 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4560 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004561 },
4562 "tile": {
4563 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004564 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004565 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004566 "build_fcn": (
4567 build_tile,
4568 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004569 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004570 TosaArgGen.agTile,
4571 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004572 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004573 "error_if_validators": (
4574 TosaErrorValidator.evWrongInputType,
4575 TosaErrorValidator.evWrongOutputType,
4576 TosaErrorValidator.evWrongInputList,
4577 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004578 TosaErrorValidator.evRankMismatch,
4579 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004580 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004581 "data_gen": {
4582 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4583 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004584 },
4585 "transpose": {
4586 "op": Op.TRANSPOSE,
4587 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004588 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004589 "build_fcn": (
4590 build_transpose,
4591 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004592 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004593 TosaArgGen.agTranspose,
4594 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004595 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004596 "error_if_validators": (
4597 TosaErrorValidator.evIndexOutsideBounds,
4598 TosaErrorValidator.evIndexUsedTwice,
4599 TosaErrorValidator.evWrongInputType,
4600 TosaErrorValidator.evWrongOutputType,
4601 TosaErrorValidator.evWrongInputList,
4602 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004603 TosaErrorValidator.evWrongRank,
4604 TosaErrorValidator.evRankMismatch,
4605 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004606 ),
evacha0198477222024-01-26 12:25:32 +00004607 "data_gen": {
4608 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4609 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004610 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004611 # Data nodes
4612 "const": {
4613 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004614 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004615 "build_fcn": (
4616 build_const,
4617 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004618 TosaTensorValuesGen.tvgLazyGenDefault,
4619 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004620 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004621 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004622 "data_gen": {
4623 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4624 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004625 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004626 "identity": {
4627 "op": Op.IDENTITY,
4628 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004629 "build_fcn": (
4630 build_unary,
4631 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004632 TosaTensorValuesGen.tvgLazyGenDefault,
4633 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004634 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004635 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004636 "data_gen": {
4637 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004639 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004640 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004641 "gather": {
4642 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004643 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004644 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004645 "build_fcn": (
4646 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004647 TosaTensorGen.tgGather,
4648 TosaTensorValuesGen.tvgGather,
4649 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004650 ),
James Ward24dbc422022-10-19 12:20:31 +01004651 "types": (
4652 DType.INT8,
4653 DType.INT16,
4654 DType.INT32,
4655 DType.FP16,
4656 DType.BF16,
4657 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004658 DType.FP8E4M3,
4659 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004660 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004661 "error_if_validators": (
4662 TosaErrorValidator.evWrongInputType,
4663 TosaErrorValidator.evWrongOutputType,
4664 TosaErrorValidator.evWrongInputList,
4665 TosaErrorValidator.evWrongOutputList,
4666 TosaErrorValidator.evWrongRank,
4667 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004668 "data_gen": {
4669 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4670 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004671 },
4672 "scatter": {
4673 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004674 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004675 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004676 "build_fcn": (
4677 build_scatter,
4678 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004679 TosaTensorValuesGen.tvgScatter,
4680 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004681 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004682 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004683 "error_if_validators": (
4684 TosaErrorValidator.evWrongInputType,
4685 TosaErrorValidator.evWrongOutputType,
4686 TosaErrorValidator.evWrongInputList,
4687 TosaErrorValidator.evWrongOutputList,
4688 TosaErrorValidator.evWrongRank,
4689 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004690 "data_gen": {
4691 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4692 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004693 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004694 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 "resize": {
4696 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004697 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004698 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004699 "build_fcn": (
4700 build_resize,
4701 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004702 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004703 TosaArgGen.agResize,
4704 ),
James Ward24dbc422022-10-19 12:20:31 +01004705 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004706 "invalid_test_validators": (
4707 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004708 ),
4709 "error_if_validators": (
4710 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004711 TosaErrorValidator.evScaleSmallerEqualZero,
4712 TosaErrorValidator.evScaleNLargerMax,
4713 TosaErrorValidator.evScaleDLargerMax,
4714 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004715 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004716 TosaErrorValidator.evBorderSmallerMin,
4717 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004718 TosaErrorValidator.evWrongInputType,
4719 TosaErrorValidator.evWrongOutputType,
4720 TosaErrorValidator.evWrongRank,
4721 TosaErrorValidator.evWrongInputList,
4722 TosaErrorValidator.evWrongOutputList,
4723 TosaErrorValidator.evBatchMismatch,
4724 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004725 TosaErrorValidator.evResizeOutputShapeMismatch,
4726 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004727 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004728 "data_gen": {
4729 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4730 },
4731 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004732 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004733 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004734 "cast": {
4735 "op": Op.CAST,
4736 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004737 "build_fcn": (
4738 build_cast,
4739 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004740 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004741 TosaArgGen.agCast,
4742 ),
James Ward8b390432022-08-12 20:48:56 +01004743 "types": (
4744 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004745 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004746 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004747 DType.INT8,
4748 DType.INT16,
4749 DType.INT32,
4750 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004751 DType.FP8E4M3,
4752 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004753 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004754 "error_if_validators": (
4755 TosaErrorValidator.evWrongInputType,
4756 TosaErrorValidator.evWrongOutputType,
4757 TosaErrorValidator.evWrongInputList,
4758 TosaErrorValidator.evWrongOutputList,
4759 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004760 "data_gen": {
4761 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4762 },
4763 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004764 },
4765 "rescale": {
4766 "op": Op.RESCALE,
4767 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004768 "build_fcn": (
4769 build_rescale,
4770 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004771 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004772 TosaArgGen.agRescale,
4773 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004774 "types": [
4775 DType.UINT8,
4776 DType.INT8,
4777 DType.INT16,
4778 DType.INT32,
4779 DType.INT48,
4780 DType.UINT16,
4781 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004782 "error_if_validators": (
4783 TosaErrorValidator.evInputZeroPointNotZero,
4784 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004785 TosaErrorValidator.evU16InputZeroPointNotValid,
4786 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004787 TosaErrorValidator.evScaleTrue,
4788 TosaErrorValidator.evScaleNotTrue,
4789 TosaErrorValidator.evWrongInputType,
4790 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004791 TosaErrorValidator.evWrongInputList,
4792 TosaErrorValidator.evWrongOutputList,
4793 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004794 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004795 # Custom
4796 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004797 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004798 # Two variants of cond_if: one that generates one of two constant tensors (no
4799 # inputs to the basic blocks, one output), and another that either adds or subtracts two tensors
4800 # (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004801 "cond_if_const": {
4802 "op": Op.COND_IF,
4803 "operands": (0, 2),
4804 "build_fcn": (
4805 build_cond_if_const,
4806 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004807 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004808 TosaArgGen.agCondIf,
4809 ),
4810 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004811 "error_if_validators": (
4812 TosaErrorValidator.evOutputListThenGraphMismatch,
4813 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004814 TosaErrorValidator.evCondIfCondNotMatchingBool,
4815 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004816 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004817 },
4818 "cond_if_binary": {
4819 "op": Op.COND_IF,
4820 "operands": (2, 0),
4821 "build_fcn": (
4822 build_cond_if_binary,
4823 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004824 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004825 TosaArgGen.agCondIf,
4826 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004827 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004828 "error_if_validators": (
4829 TosaErrorValidator.evInputListThenGraphMismatch,
4830 TosaErrorValidator.evInputListElseGraphMismatch,
4831 TosaErrorValidator.evOutputListThenGraphMismatch,
4832 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004833 TosaErrorValidator.evCondIfCondNotMatchingBool,
4834 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004835 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004836 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004837 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004838 "while_loop": {
4839 "op": Op.WHILE_LOOP,
4840 "operands": (0, 1),
4841 "build_fcn": (
4842 build_while_loop,
4843 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004844 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004845 TosaArgGen.agWhileLoop,
4846 ),
4847 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004848 "error_if_validators": (
4849 TosaErrorValidator.evInputListOutputListMismatch,
4850 TosaErrorValidator.evInputListCondGraphMismatch,
4851 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4852 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4853 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004854 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004855 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004856 },
Luke Hutton57287132023-02-06 14:54:18 +00004857 "fft2d": {
4858 "op": Op.FFT2D,
4859 "operands": (2, 0),
4860 "rank": (3, 3),
4861 "build_fcn": (
4862 build_fft2d,
4863 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004864 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004865 TosaArgGen.agFFT2d,
4866 ),
4867 "types": [DType.FP32],
4868 "error_if_validators": (
4869 TosaErrorValidator.evWrongInputType,
4870 TosaErrorValidator.evWrongOutputType,
4871 TosaErrorValidator.evWrongInputList,
4872 TosaErrorValidator.evWrongOutputList,
4873 TosaErrorValidator.evWrongRank,
4874 TosaErrorValidator.evBatchMismatch,
4875 TosaErrorValidator.evKernelNotPowerOfTwo,
4876 TosaErrorValidator.evFFTInputShapeMismatch,
4877 TosaErrorValidator.evFFTOutputShapeMismatch,
4878 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004879 "data_gen": {
4880 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4881 },
Luke Hutton57287132023-02-06 14:54:18 +00004882 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004883 "rfft2d": {
4884 "op": Op.RFFT2D,
4885 "operands": (1, 0),
4886 "rank": (3, 3),
4887 "build_fcn": (
4888 build_rfft2d,
4889 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004890 TosaTensorValuesGen.tvgLazyGenDefault,
4891 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004892 ),
4893 "types": [DType.FP32],
4894 "error_if_validators": (
4895 TosaErrorValidator.evWrongInputType,
4896 TosaErrorValidator.evWrongOutputType,
4897 TosaErrorValidator.evWrongInputList,
4898 TosaErrorValidator.evWrongOutputList,
4899 TosaErrorValidator.evWrongRank,
4900 TosaErrorValidator.evBatchMismatch,
4901 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004902 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004903 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004904 "data_gen": {
4905 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4906 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004907 },
Won Jeon74342e52024-01-09 00:34:40 +00004908 # Shape
4909 "add_shape": {
4910 "op": Op.ADD_SHAPE,
4911 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004912 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004913 "build_fcn": (
4914 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004915 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004916 TosaTensorValuesGen.tvgAddSub,
4917 TosaArgGen.agNone,
4918 ),
4919 "types": [DType.SHAPE],
4920 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4921 },
4922 "sub_shape": {
4923 "op": Op.SUB_SHAPE,
4924 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004925 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004926 "build_fcn": (
4927 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004928 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004929 TosaTensorValuesGen.tvgAddSub,
4930 TosaArgGen.agNone,
4931 ),
4932 "types": [DType.SHAPE],
4933 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4934 },
4935 "mul_shape": {
4936 "op": Op.MUL_SHAPE,
4937 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004938 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004939 "build_fcn": (
4940 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004941 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004942 TosaTensorValuesGen.tvgMul,
4943 TosaArgGen.agNone,
4944 ),
4945 "types": [DType.SHAPE],
4946 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4947 },
4948 "div_shape": {
4949 "op": Op.DIV_SHAPE,
4950 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004951 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004952 "build_fcn": (
4953 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004954 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004955 TosaTensorValuesGen.tvgIntDiv,
4956 TosaArgGen.agNone,
4957 ),
4958 "types": [DType.SHAPE],
4959 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4960 },
4961 "concat_shape": {
4962 "op": Op.CONCAT_SHAPE,
4963 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004964 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004965 "build_fcn": (
4966 build_concat,
4967 TosaTensorGen.tgConcat,
4968 TosaTensorValuesGen.tvgConcat,
4969 TosaArgGen.agNone,
4970 ),
4971 "types": [DType.SHAPE],
4972 "error_if_validators": (),
4973 },
4974 "const_shape": {
4975 "op": Op.CONST_SHAPE,
4976 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004977 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004978 "build_fcn": (
4979 build_const,
4980 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004981 TosaTensorValuesGen.tvgLazyGenDefault,
4982 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004983 ),
4984 "types": [DType.SHAPE],
4985 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004986 }
4987
Kevin Cheng550ccc52021-03-03 11:21:43 -08004988
Eric Kunzee5e26762020-10-13 16:11:07 -07004989class OutputShaper:
4990 # Methods in this class compute the expected output shape and datatype
4991 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; no per-instance state is initialised here."""
4994
4995 # These methods compute the output shape and dtype, create the output tensor
4996 # via ser.addOutput() and return whatever that call returns
4997 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an elementwise binary operator.

    On the normal path (``error_name is None``) every dimension where ``a``
    has extent 1 takes ``b``'s extent instead; all other dimensions keep
    ``a``'s extent. The result is then deliberately corrupted for some
    ERROR_IF cases:

    * ``ErrorIf.DimensionMismatch``: one randomly chosen dimension grows by 1.
    * ``ErrorIf.WrongOutputType``: the dtype is drawn from a fixed list of
      dtypes, excluding ``a.dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[dim] if (extent == 1 and error_name is None) else extent
        for dim, extent in enumerate(a.shape)
    ]

    # Drawn on every call (not only for DimensionMismatch), so RNG consumption
    # does not depend on error_name.
    bump_dim = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_dim] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    # Any candidate except the input dtype counts as a wrong output type.
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005030
5031 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary operator that does not broadcast.

    Both inputs must have identical rank, dtype and per-dimension extents
    (checked with asserts). The output takes a fresh list copy of ``a``'s
    shape and ``a``'s dtype.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor (``shape`` and ``dtype`` are read).

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for extent_a, extent_b in zip(a.shape, b.shape):
        assert extent_a == extent_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005042
5043 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an elementwise unary operator.

    The output always has ``a``'s shape. Its dtype is ``a.dtype``, except for
    ``ErrorIf.WrongOutputType``, where a different dtype is drawn at random
    from a fixed candidate list.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        a: input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Any candidate except the input dtype counts as a wrong output type.
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005061
5062 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for SELECT(cond, a, b).

    The shape follows ``cond``. On the normal path (``error_name is None``),
    any dimension where ``cond`` has extent 1 becomes the maximum extent of
    ``cond``, ``a`` and ``b`` at that position. For
    ``ErrorIf.DimensionMismatch`` one randomly chosen dimension then grows
    by 1. For ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed
    list, excluding ``a.dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        cond: condition tensor (``shape`` is read).
        a: first value tensor (``shape`` and ``dtype`` are read).
        b: second value tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = []
    for idx, cond_extent in enumerate(cond.shape):
        if broadcasting and cond_extent == 1:
            out_shape.append(max(cond_extent, a.shape[idx], b.shape[idx]))
        else:
            out_shape.append(cond_extent)

    # Drawn on every call (not only for DimensionMismatch), so RNG consumption
    # does not depend on error_name.
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005095
5096 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an elementwise comparison operator.

    Broadcasting: wherever ``a`` has extent 1 and ``b`` has a dimension at the
    same index, ``b``'s extent is used; otherwise ``a``'s extent is kept.
    For ``ErrorIf.DimensionMismatch`` one randomly chosen dimension grows
    by 1. The dtype is ``DType.BOOL``, except for
    ``ErrorIf.WrongOutputType``, where it is drawn from a fixed list of
    non-BOOL dtypes.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (extent == 1 and b_rank > idx) else extent
        for idx, extent in enumerate(a.shape)
    ]

    # Drawn on every call (not only for DimensionMismatch), so RNG consumption
    # does not depend on error_name.
    bump_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[bump_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005129
5130 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a REDUCE_* operator.

    The output copies ``a``'s shape with ``axis`` collapsed to 1. The
    collapse is skipped for the axis-related ERROR_IF cases listed below.
    For ``ErrorIf.ShapeOfAxisNotOne``, an extent of 1 on ``axis`` is replaced
    by a random value in [2, 10). For ``ErrorIf.WrongOutputType`` the dtype
    is drawn from a fixed list, excluding ``a.dtype``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor (``shape`` list and ``dtype`` are read).
        axis: index of the reduced dimension.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005158
5159 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for ARGMAX.

    The output shape is ``a``'s shape with ``axis`` removed. The removal is
    skipped for the out-of-range axis ERROR_IF cases. Optional corruptions:

    * ``ErrorIf.ArgmaxOutputRankMismatch``: randomly either drop the leading
      dimension (only if more than one remains) or append a trailing 1.
    * ``ErrorIf.ArgmaxOutputShapeMismatch``: add a random value in [1, 10)
      to every dimension.

    The dtype is ``DType.INT32``, except for ``ErrorIf.WrongOutputType``,
    where another dtype is drawn from a fixed list.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor (``shape`` list is read).
        axis: index of the dimension being reduced.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # One draw per dimension, in dimension order.
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    mismatched = list(set(candidates) - {DType.INT32})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005194
5195 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial output extent is
        (in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1
    with padding[0:2] applied to height and padding[2:4] to width.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``; used only for ERROR_IF cases.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accumulator dtype; used as the output dtype on the
            normal path.
        strides: (height, width) strides.
        padding: four padding values, height pair first, then width pair.
        dilations: (height, width) dilations.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # These branches must be mutually exclusive. With a plain `if` for the
        # FP8 check, its `else` overwrote the FP16 exclusions, so FP16 inputs
        # could be given a "wrong" output dtype that is actually legal.
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 inputs -- presumably
            # both are valid outputs depending on the accumulator; confirm
            # against the spec.
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        # NOTE(review): gtu.usableDTypes presumably yields the usable dtypes
        # minus `excludes`; verify against generator.tosa_utils.
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005248
5249 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. For spatial axis k (0=D,
    1=H, 2=W) the output extent is
        (ifm.shape[k+1] - 1 + padding[2k] + padding[2k+1]
         - (filter.shape[k+1] - 1) * dilations[k]) // strides[k] + 1

    For ``ErrorIf.ConvOutputShapeMismatch``, one or all spatial extents grow
    by a random multiple of their stride. The dtype is ``accum_dtype`` (INT32
    for ``ErrorIf.WrongInputType``). For ``ErrorIf.WrongOutputType`` it is
    replaced by a random dtype from ``gtu.usableDTypes`` with the expected
    type(s) excluded.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accumulator dtype; output dtype on the normal path.
        strides: (d, h, w) strides.
        padding: six padding values, in pairs for d, h, w.
        dilations: (d, h, w) dilations.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    spatial = []
    for axis in range(3):
        span = (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        )
        spatial.append(span // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 picks D/H/W alone; 4 picks all three. Growing by whole
        # strides avoids also triggering the non-integer output case.
        for axis in range(3):
            if change in (axis + 1, 4):
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # Some potentially correct output dtype when the input type is wrong.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5310
5311 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). For spatial axis k (0=H,
    1=W) the output extent is
        (ifm.shape[k+1] - 1 + padding[2k] + padding[2k+1]
         - (filter.shape[k] - 1) * dilations[k]) // strides[k] + 1

    For ``ErrorIf.ConvOutputShapeMismatch``, H, W or both grow by a random
    multiple of their stride. The dtype is ``accum_dtype`` (INT32 for
    ``ErrorIf.WrongInputType``). For ``ErrorIf.WrongOutputType`` it is
    replaced by a random dtype from ``gtu.usableDTypes`` with the expected
    type(s) excluded.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the output.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accumulator dtype; output dtype on the normal path.
        strides: (h, w) strides.
        padding: four padding values, height pair first, then width pair.
        dilations: (h, w) dilations.
        error_name: optional ErrorIf case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    spatial = [
        (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        )
        // strides[axis]
        + 1
        for axis in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change 1/2 picks H/W alone; 3 picks both. Growing by whole strides
        # avoids also triggering the non-integer output case.
        for axis in range(2):
            if change in (axis + 1, 3):
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    out_h, out_w = spatial
    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    # Some potentially correct output dtype when the input type is wrong.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005361
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator.

    ifm is NHWC. The output keeps ifm's N and C dimensions. H and W are
    derived from kernel, stride and pad: pad[0]/pad[1] apply to the height
    and pad[2]/pad[3] to the width. ERROR_IF variants may enlarge the
    spatial size or pick a wrong output dtype.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride/padding means the test is invalid anyway: use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)
        # Grow by whole multiples of the stride so the non-integer
        # output-size error case is not hit instead.
        if which in [1, 3]:
            out_h = out_h + (rng.choice(deltas) * stride[0])
        if which in [2, 3]:
            out_w = out_w + (rng.choice(deltas) * stride[1])
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([ifm.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005401
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    input is (N, IC) and filter is (OC, IC); the output is (N, OC).
    The output dtype is always accum_dtype. It has already been validated
    in arg_gen, or deliberately invalidated there for ERROR_IF tests.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005414
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    a is (N, H, C) and b is (N, C, W); the output is (N, H, W).

    Output dtype:
      * ErrorIf.WrongOutputType: a dtype picked from those listed below as
        incorrect for a.dtype. Input dtypes not listed fall back to any
        usable dtype other than accum_dtype.
      * ErrorIf.WrongInputType: INT32, a potentially correct output dtype.
      * otherwise: accum_dtype (validated in arg_gen).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # Previously any other input dtype left incorrect_types unbound
            # (UnboundLocalError). Fall back to every usable dtype except the
            # expected accumulator dtype.
            incorrect_types = tuple(gtu.usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005480
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor for CONCAT of `inputs` along `axis`.

    The output shape is the first input's shape, with the `axis` dimension
    summed over all inputs. For ERROR_IF cases where that sum cannot be
    formed (rank mismatch, invalid axis), the first input's shape is used
    unchanged.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    summable = error_name not in [
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if summable:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Deliberately break the sum along the concat axis.
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        wrong_dtypes = list(candidate_dtypes - set([first.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005516
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD.

    Output dimension i is a.shape[i] + padding[i][0] + padding[i][1].
    ERROR_IF variants may enlarge one dimension, change the rank, or pick
    a wrong output dtype.
    """
    out_shape = a.shape.copy()
    for i, size in enumerate(out_shape):
        out_shape[i] = padding[i][0] + padding[i][1] + size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Enlarge one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005547
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for DIM: a single-element SHAPE tensor.

    `a` and `axis` do not influence the result here. For
    ErrorIf.WrongOutputType, a non-SHAPE dtype is chosen instead.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.SHAPE
    else:
        non_shape_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape_dtypes)))

    return ser.addOutput([1], out_dtype)
5568
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE to `shape`.

    For ErrorIf.TensorSizeInputOutputMismatch, every dimension is increased
    by a random amount in [1, 10). For ErrorIf.WrongOutputType, a dtype
    different from a.dtype is chosen.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005593
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor for SLICE; the output shape is `size`.

    ERROR_IF variants may pick a wrong dtype, perturb every dimension, use
    the input's shape instead, or change the rank.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([input.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Dims <= 2 are only increased; subtracting could make them <= 0.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005627
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE.

    Output dimension i is a.shape[i] * multiples[i]. ERROR_IF variants may
    change the rank or pick a wrong output dtype.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for i in range(len(out_shape)):
        out_shape[i] *= multiples[i]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005656
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for TRANSPOSE.

    Output dimension i is a.shape[perms[i]], except when perms is
    deliberately invalid (out-of-bounds or repeated index). In that case
    a.shape is kept as-is.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    perms_valid = error_name not in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if perms_valid:
        for dst in range(len(out_shape)):
            out_shape[dst] = a.shape[perms[dst]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005689
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for GATHER.

    Unless a wrong-rank error is being generated, values must be rank 3,
    indices rank 2, and they must agree in dimension 0. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]].
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([values.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005715
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for SCATTER.

    Unless a wrong-rank error is being generated, values_in and input must
    be rank 3 and indices rank 2, with matching N (dim 0), W (dim 1) and
    C (dim 2) as asserted below. The output uses values_in.shape (the same
    object, not a copy).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values_in.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([values_in.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005746
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for TABLE.

    The output has the input's shape. Its dtype is INT32 for INT16 input
    and INT8 otherwise. For ErrorIf.WrongOutputType, a different dtype from
    a fixed list is chosen instead.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [
            dtype
            for dtype in [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005766
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for RESIZE.

    scale is read as [y_n, y_d, x_n, x_d]; offset[0]/border[0] apply to the
    height and offset[1]/border[1] to the width. Per axis:
        out = ((in - 1) * n - offset + border) // d + 1
    `mode` and `input_dtype` are not used here. ERROR_IF variants clamp or
    perturb the computed sizes and/or the batch, channel or rank.
    """
    y_n, y_d, x_n, x_d = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the output size can still be computed.
        y_n, y_d, x_n, x_d = (max(v, 1) for v in (y_n, y_d, x_n, x_d))

    oh = ((input.shape[1] - 1) * y_n - offset[0] + border[0]) // y_d + 1
    ow = ((input.shape[2] - 1) * x_n - offset[1] + border[1]) // x_d + 1

    if error_name is not None:
        # Changed scale/offset/border can give an invalid size: keep it at
        # least 1, and below the maximum unless that limit is under test.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            cap = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, cap), min(ow, cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(size, step):
            # Move by one denominator step so the non-integer error case is
            # not hit; step down instead if stepping up would reach the max.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in [1, 3]:
            oh = _shift(oh, y_d)
        if change in [2, 3]:
            ow = _shift(ow, x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005845
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type conversion.

    The output has val's shape and the requested out_dtype.
    """
    out_shape = val.shape
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005849
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D.

    NOTE: for ErrorIf.ConvOutputShapeMismatch the caller's `output_shape`
    is modified in place. Dims 1 and/or 2 are increased by 1-3.

    Output dtype is accum_dtype, INT32 for ErrorIf.WrongInputType, or a
    usable dtype outside the expected set for ErrorIf.WrongOutputType.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        change = rng.choice(deltas)
        for dim, triggers in ((1, (1, 3)), (2, (2, 3))):
            if change in triggers:
                output_shape[dim] = output_shape[dim] + rng.choice(deltas)

    # INT32 is some potentially correct output dtype when the input is wrong.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 input
            # (presumably both are acceptable outputs).
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005875
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D and return them as a list.

    ifm1 and ifm2 must share a dtype. Their shapes must match, and be
    rank 3, unless the corresponding ERROR_IF is being generated. Both
    outputs get the same shape and dtype as the inputs, except where an
    ERROR_IF perturbs them.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        out_shape[modify_dim] += rng.integers(1, 10)

    return [
        serializer.addOutput(out_shape, out_dtype),
        serializer.addOutput(out_shape, out_dtype),
    ]
5906
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D and return them as a list.

    The output shape is the input shape with the last dimension replaced by
    last // 2 + 1. Both outputs share that shape and value's dtype, except
    where an ERROR_IF perturbs them.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        modify_dim = rng.choice([1, 2])
        out_shape[modify_dim] += rng.integers(1, 10)

    return [
        serializer.addOutput(out_shape, out_dtype),
        serializer.addOutput(out_shape, out_dtype),
    ]
Won Jeon74342e52024-01-09 00:34:40 +00005931
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for ADD_SHAPE.

    The output has a's shape and dtype SHAPE. Unless a rank mismatch is
    being generated, a and b must have equal rank and dtype. For
    ErrorIf.DimensionMismatch one dimension is increased by 1. For
    ErrorIf.WrongOutputType the dtype comes from
    gtu.get_wrong_output_type(a.dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

    shape = list(a.shape)

    # Draw the fuzz index unconditionally, even when it is unused, so the RNG
    # stream (and every value generated after this call) is unchanged.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.SHAPE
    return ser.addOutput(shape, outputDType)