blob: a03a763fd59b04065feee190b5886437f08d9d4a [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Banner written at the top of every generated C++ meta data file.
# The copyright year is fixed at the moment this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
Won Jeon2c34b462024-02-06 18:37:00 +000079 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010080 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Won Jeon2c34b462024-02-06 18:37:00 +0000155 if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Jerry Ge20ab3df2024-01-26 16:56:55 +0000198 elif dtype == DType.INT16:
199 return np.int16(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype == DType.UINT16:
201 return np.uint16(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000202 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100203 return np.int64(self.rng.integers(low=low, high=high, size=shape))
Won Jeon2c34b462024-02-06 18:37:00 +0000204 elif dtype in (
205 DType.FP16,
206 DType.BF16,
207 DType.FP32,
208 DType.FP8E4M3,
209 DType.FP8E5M2,
210 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100211 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
212
213 if dtype == DType.FP16:
214 return np.float16(f_tensor)
215 else:
216 f32_tensor = np.float32(f_tensor)
217 if dtype == DType.BF16:
218 # Floor the last 16 bits of each f32 value
219 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
Won Jeon2c34b462024-02-06 18:37:00 +0000220 elif dtype == DType.FP8E4M3:
221 return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
222 elif dtype == DType.FP8E5M2:
223 return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
Jeremy Johnson1271c442023-09-05 11:39:26 +0100224 else:
225 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700226 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100227 # All other integer types
228 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700229
Kevin Cheng989cb052021-04-28 16:29:44 -0700230 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 placeholders = []
232
Kevin Cheng989cb052021-04-28 16:29:44 -0700233 assert len(shape_list) == len(dtype_list)
234
Jeremy Johnson1271c442023-09-05 11:39:26 +0100235 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700236 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100237 if not self.args.lazy_data_gen:
238 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700239 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700240
241 return placeholders
242
Kevin Cheng989cb052021-04-28 16:29:44 -0700243 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 consts = []
245
Kevin Cheng989cb052021-04-28 16:29:44 -0700246 assert len(shape_list) == len(dtype_list)
247
Jeremy Johnson1271c442023-09-05 11:39:26 +0100248 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700249 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100250 if not self.args.lazy_data_gen:
251 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700252 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700253
254 return consts
255
256 def makeShape(self, rank):
257 if self.targetted_shape:
258 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800259 return np.int32(
260 self.rng.integers(
261 low=self.args.tensor_shape_range[0],
262 high=self.args.tensor_shape_range[1],
263 size=rank,
264 )
265 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700266
267 def setTargetShape(self, shape):
268 self.targetted_shape = shape
269
270 def randInt(self, low=0, high=256):
271 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
272
273 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100274 low, high = self.getDTypeRange(dtype)
275
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100276 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100277 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100278 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100279 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100280 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100281 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
282 return gtu.vect_f32_to_bf16(rand_f32)
Won Jeon2c34b462024-02-06 18:37:00 +0000283 elif dtype == DType.FP8E4M3:
284 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
285 return gtu.vect_f32_to_fp8e4m3(rand_f32)
286 elif dtype == DType.FP8E5M2:
287 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
288 return gtu.vect_f32_to_fp8e5m2(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 elif dtype == DType.BOOL:
290 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000291 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 # Special size
293 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
295 return np.int32(self.rng.integers(low, high, size=1))[0]
296
297 def shapeStr(self, shape):
298
299 sStr = []
300 # Convert to strings
301 for i in shape:
302 sStr.append(str(i))
303
Kevin Cheng550ccc52021-03-03 11:21:43 -0800304 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100306 def typeStr(self, dtype):
307 if isinstance(dtype, list) or isinstance(dtype, tuple):
308 assert len(dtype) >= 2
309 strs = [self.typeStr(t) for t in dtype]
310 # Limit types to the first 2 as the 3rd is the accumulator
311 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700312 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100313 if dtype in gtu.DTYPE_ATTRIBUTES:
314 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700315 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100316 raise Exception(
317 "Unknown dtype, cannot convert to string: {}".format(dtype)
318 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100320 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100321 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100322 if dtype in gtu.DTYPE_ATTRIBUTES:
323 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100325 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700326
Luke Hutton57287132023-02-06 14:54:18 +0000327 def constrictBatchSize(self, shape):
328 # Limit the batch size unless an explicit target shape set
329 if self.args.max_batch_size and not self.args.target_shapes:
330 shape[0] = min(shape[0], self.args.max_batch_size)
331 return shape
332
James Ward30124a82023-02-02 14:56:33 +0000333 def makeDimension(self):
334 return self.randInt(
335 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
336 )
337
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100338 def tensorComplianceMetaData(
339 self, op, inputType, argsDict, outputTensor, errorName
340 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000341 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
342 UNSUPPORTED_NON_FP32_INPUT_OPS = (
343 Op.MATMUL,
344 Op.CONV2D,
345 Op.FULLY_CONNECTED,
346 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000347 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000348 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000349 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100350 if (
351 errorName
352 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000353 or (
354 not gtu.dtypeIsSupportedByCompliance(inputType)
355 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
356 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100357 ):
358 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100359 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100360
Jeremy Johnson1271c442023-09-05 11:39:26 +0100361 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100362 compliance_tens = {
363 "mode": None,
364 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
365 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
366 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100367 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
368 mode = gtu.ComplianceMode.DOT_PRODUCT
369 compliance_tens["dot_product_info"] = {
370 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100371 "ks": int(argsDict["ksb"])
372 if "ksb" in argsDict
373 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100374 }
375 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
376 mode = gtu.ComplianceMode.FP_SPECIAL
377 elif "compliance" in op and "ulp" in op["compliance"]:
378 mode = gtu.ComplianceMode.ULP
379 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000380 elif "compliance" in op and "relative" in op["compliance"]:
381 mode = gtu.ComplianceMode.RELATIVE
382 compliance_tens["relative_info"] = {
383 "max": argsDict["max_abs_value"],
384 "scale": op["compliance"]["relative"],
385 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100386 elif op["op"] == Op.REDUCE_PRODUCT:
387 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000388 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000389 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000390 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000391 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
392 compliance_tens["abs_error_info"] = {
393 "lower_bound": op["compliance"]["abs_error_lower_bound"]
394 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100395 else:
396 mode = gtu.ComplianceMode.EXACT
397 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
398
399 return compliance_tens
400
401 # Build Op functions
402 # Create the output tensor (calling OutputShaper as needed)
403 # Do final tweaks to attributes (if necessary for errorIf)
404 # Add Op into graph
405 # Return resulting tensor information or BuildInfo
406
407 class BuildInfo:
408 """Enhanced build information containing result tensor and associated compliance dict."""
409
410 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000411 if isinstance(resultTensor, list):
412 assert complianceDict is None or isinstance(complianceDict, list)
413 self.resultTensorList = resultTensor
414 self.complianceDictList = complianceDict
415 else:
416 self.resultTensorList = [resultTensor]
417 if complianceDict is None:
418 self.complianceDictList = None
419 else:
420 self.complianceDictList = [complianceDict]
421
422 def getComplianceInfo(self):
423 if self.complianceDictList is None:
424 return None
425 else:
426 tens_dict = {}
427 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
428 if comp is not None:
429 tens_dict[tens.name] = comp
430
431 if tens_dict:
432 # Have some compliance data, so return the info
433 compliance = {
434 "version": "0.1",
435 "tensors": tens_dict,
436 }
437 else:
438 compliance = None
439 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700440
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000441 def build_unary(
442 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
443 ):
444 assert len(inputs) == 1
445 a = inputs[0]
446 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100447
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000448 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100449
450 # Ensure new output type has correct qinfo
451 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000452 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000453 qinfo = [
454 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000455 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000456 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100457
458 # Invalidate Input/Output list for error if checks.
459 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000460 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100461 pCount, cCount = op["operands"]
462 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000463 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
464 self, error_name, input_list, output_list
465 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100466
Les Bell729b0352021-11-24 10:28:21 +0000467 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100468 self.ser,
469 validator_fcns,
470 error_name,
471 op=op,
472 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000473 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000475 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100476 input_list=input_list,
477 output_list=output_list,
478 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000479 ):
480 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100481
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000482 attr = None
483 if op["op"] == Op.NEGATE:
484 attr = ts.TosaSerializerAttribute()
485 attr.NegateAttribute(qinfo[0], qinfo[1])
486
487 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000488
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000489 compliance = self.tensorComplianceMetaData(
490 op, a.dtype, args_dict, result_tensor, error_name
491 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000492 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700493
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000494 def build_binary_broadcast(
495 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
496 ):
497 assert len(inputs) == 2
498 a, b = inputs
499 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000500 self.ser, self.rng, a, b, error_name
501 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100502
503 # Invalidate Input/Output list for error if checks.
504 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000505 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100506 pCount, cCount = op["operands"]
507 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000508 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
509 self, error_name, input_list, output_list
510 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100511
Les Bell729b0352021-11-24 10:28:21 +0000512 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100513 self.ser,
514 validator_fcns,
515 error_name,
516 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000517 input1=a,
518 input2=b,
519 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000520 output_dtype=result_tensor.dtype,
521 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100522 input_list=input_list,
523 output_list=output_list,
524 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000525 ):
526 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100527
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000529
Jeremy Johnson9a758382023-11-07 16:27:35 +0000530 compliance = self.tensorComplianceMetaData(
531 op, a.dtype, args_dict, result_tensor, error_name
532 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000533
534 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700535
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100536 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700537 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000538 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700539 return result_tens
540
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000541 def build_arithmetic_right_shift(
Jeremy Johnson587cc842024-02-08 11:45:44 +0000542 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000543 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000544 assert len(inputs) == 2
545 a, b = inputs
546 round = args_dict["round"]
547 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000548 self.ser, self.rng, a, b, error_name
549 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100550
551 # Invalidate Input/Output list for error if checks.
552 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000553 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100554 pCount, cCount = op["operands"]
555 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000556 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
557 self, error_name, input_list, output_list
558 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100559
Les Bell729b0352021-11-24 10:28:21 +0000560 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100561 self.ser,
562 validator_fcns,
563 error_name,
564 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000565 input1=a,
566 input2=b,
567 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000568 output_dtype=result_tensor.dtype,
569 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100570 input_list=input_list,
571 output_list=output_list,
572 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000573 ):
574 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800575
576 attr = ts.TosaSerializerAttribute()
577 attr.ArithmeticRightShiftAttribute(round)
578
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000579 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000580
581 compliance = self.tensorComplianceMetaData(
582 op, a.dtype, args_dict, result_tensor, error_name
583 )
584
585 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800586
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100587 def build_mul(
588 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
589 ):
590 assert len(inputs) == 2
591 a, b = inputs
592 shift = args_dict["shift"]
593
594 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000595 self.ser, self.rng, a, b, error_name
596 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700597
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100598 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100599 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100600 result_tensor.setDtype(DType.INT32)
601
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100602 if error_name == ErrorIf.WrongOutputType:
603 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
604 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100605 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100606
607 # Invalidate Input/Output list for error if checks.
608 input_list = [a.name, b.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100609 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100610 pCount, cCount = op["operands"]
611 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000612 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
613 self, error_name, input_list, output_list
614 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100615
Les Bell729b0352021-11-24 10:28:21 +0000616 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617 self.ser,
618 validator_fcns,
619 error_name,
620 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000621 input1=a,
622 input2=b,
623 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100624 output_dtype=result_tensor.dtype,
625 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100626 input_list=input_list,
627 output_list=output_list,
628 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000629 ):
630 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700631
Kevin Chengaee1fac2020-11-11 13:54:06 -0800632 attr = ts.TosaSerializerAttribute()
633 attr.MulAttribute(shift)
634
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000635 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100636
637 compliance = self.tensorComplianceMetaData(
638 op, a.dtype, args_dict, result_tensor, error_name
639 )
640
641 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700642
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TABLE operator test.

    Expects exactly one input tensor; the lookup table is taken from
    ``args_dict["table"]``. Returns None when the ERROR_IF validators reject
    the configuration, otherwise a ``TosaTestGen.BuildInfo`` holding the
    result tensor and its compliance metadata.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    lookup = args_dict["table"]
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, ifm, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(lookup)

    # Invalidate the operand name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test from three input tensors.

    ``inputs`` is (condition, second operand, third operand); the output
    tensor is produced by ``OutputShaper.selectOp``. Returns None when the
    ERROR_IF validators reject the test, otherwise a ``TosaTestGen.BuildInfo``.
    """
    assert len(inputs) == 3
    cond, x, y = inputs

    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, x, y, error_name)

    # Invalidate the operand name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, x.name, y.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input1=cond,
        input2=x,
        input3=y,
        input_shape=x.shape,
        input_dtype=x.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, x.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700733
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test from two input tensors.

    The output tensor comes from ``OutputShaper.binaryComparisonOp``. Returns
    None when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # Invalidate the operand name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the two operand counts.
        inputs: A single input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` selects the axis.
        validator_fcns: ERROR_IF validator functions (defaults to None, like
            the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test from a single input tensor.

    Reads ``stride``, ``pad`` and ``kernel`` from ``args_dict``. ``acc_type``
    is read when present; max_pool has none, so ``DType.UNKNOWN`` is used
    instead. A missing ``qinfo`` is serialized as zero points ``[0, 0]``.
    Returns None when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo``.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure the new output type has correct qinfo for WrongInputType tests.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # Invalidate the operand name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100900
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (4 values) and ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ``ConvAttribute``. When
            None, zero points of 0 are serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700982
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (6 values) and ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ``ConvAttribute``. When
            None, zero points of 0 are serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001064
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            (4 values, the out_pad) and ``out_shape``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to
            ``TransposeConvAttribute``. When None, zero points of 0 are
            serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001139
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the operand counts.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Test arguments; reads ``acc_type``, ``stride``, ``pad``
            and ``dilation``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ``ConvAttribute``. When
            None, zero points of 0 are serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001220
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the two operand counts.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: Test arguments; reads ``acc_type``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to
            ``FullyConnectedAttribute``. When None, zero points of 0 are
            serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weights, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weights.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the two operand counts.
        inputs: The two input tensors ``[a, b]``.
        args_dict: Test arguments; reads ``acc_type``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ``MatMulAttribute``.
            When None, zero points of 0 are serialized.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and indexing it below would raise TypeError.
    # Fall back to zero points of 0. This is done after validation so the
    # validators still receive the caller's value unchanged.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001326
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test along ``args_dict["axis"]``.

    Args:
        op: Operator description dict; ``op["op"]`` is serialized and
            ``op["operands"]`` holds the two operand counts.
        inputs: A single input tensor.
        args_dict: Test arguments; ``axis`` is read. For valid REDUCE_PRODUCT
            tests, ``n`` is also written into this dict (see below).
        validator_fcns: ERROR_IF validator functions (defaults to None, like
            the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Unused.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance. Note this mutates the
        # caller's args_dict, so tensorComplianceMetaData can read "n".
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001375
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CLAMP operator with randomly drawn min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` case the draw is repeated until the values
    differ, and they are then assigned the other way round, so max < min.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence holding exactly one input tensor.
        args_dict: Argument dict forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    draws = [self.getRandNumberDType(src.dtype), self.getRandNumberDType(src.dtype)]
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(draws), max(draws)
    else:
        # Equal values could not violate max >= min, so redraw until distinct.
        while draws[0] == draws[1]:
            draws = [
                self.getRandNumberDType(src.dtype),
                self.getRandNumberDType(src.dtype),
            ]
        # Deliberately inverted bounds.
        min_val, max_val = max(draws), min(draws)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    # ClampAttribute takes (builder, min_int, max_int, min_fp, max_fp); only
    # the pair matching the input dtype carries the drawn bounds.
    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        int_bounds, fp_bounds = (0, 0), (min_val, max_val)
    elif src.dtype in (DType.INT8, DType.INT16):
        int_bounds, fp_bounds = (min_val, max_val), (0, 0)
    else:
        # to avoid internal error for incorrect input types
        int_bounds, fp_bounds = (0, 0), (0, 0)
    clamp_attr.ClampAttribute(self.ser.builder, *int_bounds, *fp_bounds)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001444
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator with a random FP32 alpha.

    Unlike the ``inputs``/``args_dict`` builders in this class, this one takes
    the tensor directly. It runs no ERROR_IF validation (``validator_fcns`` is
    unused) and returns the bare result tensor rather than a ``BuildInfo``.

    Args:
        op: Operator description dict; only ``op["op"]`` is read.
        a: The input tensor.
        validator_fcns: Unused.
        error_name: Forwarded to ``OutputShaper.unaryOp``.

    Returns:
        The serialized result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    relu_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out.name], relu_attr)
    return out
1453
# NOTE(review): PRELU needs an additional type/input (only one input tensor is
# wired up below) -- confirm before relying on this builder.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator with a single input tensor and no attribute.

    No ERROR_IF validation is performed (``validator_fcns`` is unused). The
    bare result tensor is returned rather than a ``BuildInfo``.

    Args:
        op: Operator description dict; only ``op["op"]`` is read.
        a: The input tensor.
        validator_fcns: Unused.
        error_name: Forwarded to ``OutputShaper.unaryOp``.

    Returns:
        The serialized result tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1460
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a single-input operator that carries no attribute.

    The result tensor is shaped by ``OutputShaper.unaryOp``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence holding exactly one input tensor.
        args_dict: Argument dict forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    # No attribute is attached to this operator.
    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001501
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and stores it in an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and is serialized without
    an attribute.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Tensors to concatenate; ``inputs[0]`` supplies the shape and
            dtype reported to the validators and to compliance.
        args_dict: Argument dict (``"axis"`` for CONCAT).
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    is_shape_concat = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_concat else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor.name for tensor in inputs], [out.name]
    )

    first = inputs[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out.shape,
        input_dtype=first.dtype,
        output_dtype=out.dtype,
        inputs=inputs,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    concat_attr = None
    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
    self.ser.addOperator(op["op"], in_names, out_names, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001560
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator whose padding amounts come from a tensor input.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[input_tensor, padding_tensor]``.
        args_dict: Reads ``"pad"`` (used to shape the output and given to the
            validators) and ``"pad_const_int"`` / ``"pad_const_fp"`` (written
            into the PadAttribute).
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Forwarded to the ERROR_IF validators.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, pad_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # The attribute's padding list is written empty so that the padding
    # tensor (inputs[1]) is the one that gets used.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], const_int, const_fp)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, pad_tensor.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001618
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator querying one axis of the input tensor.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence holding exactly one input tensor.
        args_dict: Reads ``"axis"``, which is stored in an AxisAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``. No compliance
        metadata is produced here. Returns None when ``evValidateErrorIfs``
        returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, dim_axis, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001664
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a RESHAPE operator whose target shape is a tensor input.

    The output tensor is shaped from ``args_dict["new_shape"]``. The operator
    itself is serialized with no attribute and takes the shape from
    ``inputs[1]``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[input_tensor, shape_tensor]``.
        args_dict: Reads ``"new_shape"``; also forwarded to compliance.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs[0], inputs[1]
    target_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(
        self.ser, self.rng, src, target_shape, error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape_tensor.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001708
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a REVERSE operator along ``args_dict["axis"]``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence holding exactly one input tensor.
        args_dict: Reads ``"axis"``, which is stored in an AxisAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, None)``. No compliance
        metadata is produced here. Returns None when ``evValidateErrorIfs``
        returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001748
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TRANSPOSE operator using ``args_dict["perms"]``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: Sequence holding exactly one input tensor.
        args_dict: Reads ``"perms"``, which is stored in a TransposeAttribute
            and given to the validators.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    src = inputs[0]
    permutation = args_dict["perms"]

    out = OutputShaper.transposeOp(
        self.ser, self.rng, src, permutation, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(permutation)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001797
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SLICE operator with start/size given as tensors and constants.

    The operator consumes the ``start`` and ``size`` tensors (``inputs[1:]``).
    The constant values from ``args_dict`` shape the output, are checked by
    the validators and are also written into a SliceAttribute.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[input_tensor, start_tensor, size_tensor]``.
        args_dict: Reads ``"start"`` and ``"size"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    starts = args_dict["start"]
    sizes = args_dict["size"]

    out = OutputShaper.sliceOp(self.ser, self.rng, src, starts, sizes, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [out.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        start=starts,
        size=sizes,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    )
    if not is_valid:
        return None

    # TODO remove the slice attribute once shape dynamism support is mature.
    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(starts, sizes)
    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001849
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a TILE operator whose multiples come from a tensor input.

    The output is shaped from ``args_dict["multiples"]``. The operator is
    serialized with no attribute and takes the multiples from ``inputs[1]``.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[input_tensor, multiples_tensor]``.
        args_dict: Reads ``"multiples"``; also forwarded to compliance.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    repeat_counts = args_dict["multiples"]
    out = OutputShaper.tileOp(self.ser, self.rng, src, repeat_counts, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=src,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001894
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a GATHER operator from a values tensor and an indices tensor.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[values_tensor, indices_tensor]``. The values tensor
            supplies the shape/dtype given to the validators and compliance.
        args_dict: Argument dict forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    values_tensor, index_tensor = inputs

    out = OutputShaper.gatherOp(
        self.ser, self.rng, values_tensor, index_tensor, error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_tensor.name, index_tensor.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001937
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a SCATTER operator from values_in, indices and input tensors.

    Args:
        op: Operator description dict (``"op"`` and ``"operands"`` are read).
        inputs: ``[values_in_tensor, indices_tensor, input_tensor]``. The
            ``values_in`` tensor supplies the shape/dtype given to the
            validators and to compliance.
        args_dict: Argument dict forwarded to compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: Optional ERROR_IF case being generated.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or None when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    # Named input_tensor rather than `input` to avoid shadowing the builtin.
    values_in, index_tensor, input_tensor = inputs
    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tensor, input_tensor, error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    placeholders, consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tensor.name, input_tensor.name],
        [out.name],
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001979
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` supplies the data tensor followed by the scale, offset and
    border tensors. The mode/scale/offset/border values read from
    ``args_dict`` drive output shaping and ERROR_IF validation. The
    serialized ResizeAttribute carries only the mode; its scale/offset/border
    lists are written empty.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 4
    input_tens, scale_input, offset_input, border_input = inputs[:4]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tens,
        mode,
        scale,
        offset,
        border,
        input_tens.dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may deliberately corrupt the input/output name lists
    operand_counts = op["operands"]
    num_operands = operand_counts[0] + operand_counts[1]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [tens.name for tens in (input_tens, scale_input, offset_input, border_input)],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tens.dtype,
        output_dtype=output_dtype,
        input_shape=input_tens.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    # Scale/offset/border travel as input tensors, so the attribute lists are empty
    resize_attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, input_tens.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002058
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build a two-input, two-output identity operator.

    Each output is shaped by ``OutputShaper.unaryOp`` from its matching input.

    NOTE(review): unlike the other builders, ``op`` is handed straight to
    ``addOperator`` (not ``op["op"]``). Only the first result tensor is
    returned, rather than a ``BuildInfo``. This looks like a legacy
    signature, so confirm callers before modernising it.
    """
    outputs = [
        OutputShaper.unaryOp(self.ser, self.rng, tens, error_name)
        for tens in (val, val2)
    ]
    self.ser.addOperator(
        op, [val.name, val2.name], [out_tens.name for out_tens in outputs]
    )
    return outputs[0]
2066
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test by exposing the supplied tensor as a graph output.

    No operator is added here. The single tensor in ``inputs`` is registered
    as an output and returned together with its compliance metadata.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002079
2080 # Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST test converting the single input to ``args_dict["out_type"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, target_dtype, error_name
    )

    # ERROR_IF tests may deliberately corrupt the input/output name lists
    operand_counts = op["operands"]
    num_operands = operand_counts[0] + operand_counts[1]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002124
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    The builder works in four steps:

    1. Pick random input/output zero points for the tensor dtypes, or invalid
       non-zero ones for the zero-point ERROR_IF tests.
    2. Derive random scale factors, one per channel when ``per_channel``,
       otherwise a single one.
    3. Convert each scale to a multiplier/shift pair.
    4. For valid scale32 tests, clamp the stored input values so they satisfy
       the apply_scale_32 shift-range requirement.

    Args:
        op: operator description dict (uses ``"op"`` and ``"operands"``).
        inputs: list holding the single input tensor.
        args_dict: test arguments; reads ``"output_dtype"``, ``"scale"``
            (scale32 flag), ``"double_round"`` and ``"per_channel"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: unused here; the zero points are generated by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)

    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with Python ints. The shift can exceed 32, and
        # shifting an np.int32 scalar keeps the arithmetic in int32 under
        # NumPy 2 (NEP 50) promotion, which silently overflows and yields
        # nonsense bounds (e.g. [0, -1]).
        half_range = 1 << (int(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the stored values' dtype.
        # Compare the actual extremes: ndarray.all() only returns a boolean,
        # which made the previous range check vacuous.
        assert val_adj.size == 0 or (
            val_adj.min() >= np.iinfo(values.dtype).min
            and val_adj.max() <= np.iinfo(values.dtype).max
        )

        # Force casting back to the stored values' datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002307
def _get_condition_tensor(self, op, cond, error_name):
    """Return a constant condition tensor holding ``cond`` for COND_IF tests.

    Normally this is a rank-0 BOOL constant. The CondIfCondNotMatchingBool
    ERROR_IF test gets a wrong dtype instead. The CondIfCondShapeNotSizeOne
    test gets a shape of ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    # Must be of size 1 (rank 0) unless deliberately testing the shape check
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2324
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each return a constant.

    The supplied then/else tensors are only used for the output shape. Each
    block body is an INT32 constant filled with random values in [0, 256).
    For the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the
    matching block returns a constant with a deliberately different shape.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        # Perturb every dimension, keeping small dimensions positive
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor is shaped like the (correct) block outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Fill each block with a single constant output node
    for block, block_arr, block_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002409
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a and b.

    The block ops depend on the input dtype:

    - FP32/BF16/FP16/INT32 inputs: ADD in THEN, SUB in ELSE.
    - INT8/INT16 inputs: LOGICAL_RIGHT_SHIFT in THEN, LOGICAL_LEFT_SHIFT in
      ELSE.

    Compliance metadata is computed for whichever op the condition selects.
    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the matching block gets an input or output tensor with a perturbed shape.

    Args:
        op: operator description dict (uses ``"op"``).
        inputs: the two operand tensors ``[a, b]``.
        args_dict: test arguments; reads ``"condition"``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test being generated, or ``None``.
        qinfo: unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and its compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Perturb every dimension so the shapes mismatch, but keep small
            # dimensions positive (as build_cond_if_const does). Adding one
            # of [-3, -2, 2, 3] to a dimension of 1-3 could drive it to zero
            # or below, producing an invalid shape rather than a mismatched one.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # Block input deliberately disagrees with the operator input
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # Block output deliberately disagrees with the operator output
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2513
2514 def build_while_loop(
2515 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2516 ):
2517 assert len(inputs) == 1
2518 a = inputs[0]
2519 iter_val = args_dict["iterations"]
2520
Kevin Cheng550ccc52021-03-03 11:21:43 -08002521 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002522
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 cond_block = "COND_BLOCK"
2524 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002525
2526 attr = ts.TosaSerializerAttribute()
2527 attr.WhileLoopAttribute(cond_block, body_block)
2528
2529 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002530 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002531 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002532 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002533
2534 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002535 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2536 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002537 if error_name == ErrorIf.InputListOutputListMismatch:
2538 incorrect_acc = deepcopy(acc)
2539 for i in range(len(incorrect_acc.shape)):
2540 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2541 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2542 else:
2543 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002544
2545 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002546 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002547 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002548 [iter.name, a.name, acc.name],
2549 [iter_out.name, a_out.name, acc_out.name],
2550 attr,
2551 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002552 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002553
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002554 if error_name in [
2555 ErrorIf.InputListCondGraphMismatch,
2556 ErrorIf.InputListBodyGraphInputMismatch,
2557 ErrorIf.InputListBodyGraphOutputMismatch,
2558 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002559 incorrect_iter = deepcopy(iter)
2560 for i in range(len(incorrect_iter.shape)):
2561 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2562 if len(incorrect_iter.shape) == 0:
2563 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2564
2565 incorrect_acc = deepcopy(acc)
2566 for i in range(len(incorrect_acc.shape)):
2567 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2568
Eric Kunzee5e26762020-10-13 16:11:07 -07002569 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002570 self.ser.addBasicBlock(cond_block)
2571
Matthew Haddon630c17c2021-10-14 15:05:41 +01002572 if error_name == ErrorIf.InputListCondGraphMismatch:
2573 self.ser.addInputTensor(incorrect_iter)
2574 self.ser.addInputTensor(a)
2575 self.ser.addInputTensor(incorrect_acc)
2576 else:
2577 self.ser.addInputTensor(iter)
2578 self.ser.addInputTensor(a)
2579 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002580 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002581
2582 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002583 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002584 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002585 cond_type = DType.BOOL
2586 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2587 choice = self.rng.choice([1, 2])
2588 if choice == 1:
2589 cond_shape = [3]
2590 else:
2591 cond_shape = [1, 2]
2592 else:
2593 cond_shape = []
2594 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002595
Kevin Cheng550ccc52021-03-03 11:21:43 -08002596 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002597
2598 # BODY block (input: a, acc, iter, output: a, acc, iter)
2599 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002600 self.ser.addBasicBlock(body_block)
2601
Matthew Haddon630c17c2021-10-14 15:05:41 +01002602 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2603 self.ser.addInputTensor(incorrect_iter)
2604 self.ser.addInputTensor(a)
2605 self.ser.addInputTensor(incorrect_acc)
2606 else:
2607 self.ser.addInputTensor(iter)
2608 self.ser.addInputTensor(a)
2609 self.ser.addInputTensor(acc)
2610
Kevin Cheng550ccc52021-03-03 11:21:43 -08002611 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002612
2613 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002614 iter_body_out = self.ser.addIntermediate(
2615 incorrect_iter.shape, incorrect_iter.dtype
2616 )
2617 acc_body_out = self.ser.addIntermediate(
2618 incorrect_acc.shape, incorrect_acc.dtype
2619 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002620 else:
2621 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2622 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2623
Eric Kunzee5e26762020-10-13 16:11:07 -07002624 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2625 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2626 self.ser.addOutputTensor(iter_body_out)
2627 self.ser.addOutputTensor(a)
2628 self.ser.addOutputTensor(acc_body_out)
2629
Les Bell729b0352021-11-24 10:28:21 +00002630 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002631 self.ser,
2632 validator_fcns,
2633 error_name,
2634 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002635 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002636 ):
2637 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002638
Jeremy Johnson587cc842024-02-08 11:45:44 +00002639 compliance = self.tensorComplianceMetaData(
2640 op, a.dtype, args_dict, acc_out, error_name
2641 )
2642
2643 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002644
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize an FFT2D operator test.

    Args:
        op: TOSA_OP_LIST entry; its "op" and "operands" fields are read.
        inputs: exactly two input tensors (presumably the real and imaginary
            input planes - TODO confirm against OutputShaper.fft2dOp).
        args_dict: test arguments; the "inverse" flag is read from it.
        validator_fcns: ERROR_IF validators handed to evValidateErrorIfs.
        error_name: ERROR_IF condition under test, or None.
        qinfo: not used by this builder; kept for the build_fcn interface.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance
        entry per result, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    first, second = inputs
    is_inverse = args_dict["inverse"]

    outputs = OutputShaper.fft2dOp(self.ser, self.rng, first, second, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    out_shapes = [t.shape for t in outputs]
    out_dtypes = [t.dtype for t in outputs]

    # ERROR_IF tests may get deliberately corrupted name lists back
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [first.name, second.name], [t.name for t in outputs]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=is_inverse,
        input1=first,
        input2=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=outputs,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_pass:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(is_inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, first.dtype, args_dict, t, error_name)
        for t in outputs
    ]
    return TosaTestGen.BuildInfo(outputs, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002708
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize an RFFT2D operator test built from a single input tensor.

    Args:
        op: TOSA_OP_LIST entry; its "op" and "operands" fields are read.
        inputs: a one-element sequence holding the input tensor.
        args_dict: test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators handed to evValidateErrorIfs.
        error_name: ERROR_IF condition under test, or None.
        qinfo: not used by this builder; kept for the build_fcn interface.

    Returns:
        TosaTestGen.BuildInfo with the result tensors and one compliance
        entry per result, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    (source,) = inputs

    outputs = OutputShaper.rfft2dOp(self.ser, self.rng, source, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    out_shapes = [t.shape for t in outputs]
    out_dtypes = [t.dtype for t in outputs]

    # ERROR_IF tests may get deliberately corrupted name lists back
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [source.name], [t.name for t in outputs]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=outputs,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    ):
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, source.dtype, args_dict, t, error_name)
        for t in outputs
    ]
    return TosaTestGen.BuildInfo(outputs, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002765
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a two-input shape operator test with a single result.

    Args:
        op: TOSA_OP_LIST entry; its "op" and "operands" fields are read.
        inputs: exactly two input tensors.
        args_dict: test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators handed to evValidateErrorIfs.
        error_name: ERROR_IF condition under test, or None.
        qinfo: not used by this builder; kept for the build_fcn interface.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None if the ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholders, consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2811
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters for generating *op* tests.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) tuple and "types"
            list are read.
        shapeFilter: requested shapes; any falsy value means "no shape
            filter" and is normalised to ``[None]``.
        rankFilter: requested ranks, or None for the default rank range.
        dtypeFilter: requested dtypes, or None for all operator types.
        testType: "positive" or "negative".
        validator: ERROR_IF validator function, needed for negative tests.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or an unknown testType.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Clamp the default window to the operator's limits so no
            # unsupported ranks are generated (e.g. op rank (3, 8) -> 3..4)
            lo = max(rmin, default_test_rank_range.start)
            hi = min(rmax + 1, default_test_rank_range.stop)
            if lo < hi:
                cleanRankFilter = range(lo, hi)
            else:
                # No overlap with the default window: use the operator's
                # lowest ranks, keeping the same number of ranks
                cleanRankFilter = range(rmin, rmin + len(default_test_rank_range))
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]

    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes; list
        # entries (e.g. conv type combos) match on their first dtype
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    # Unknown test type
    return None
2892
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for operator *opName*.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: shapes to restrict generation to; None (or any falsy
            value) means no shape filter. Defaults to None rather than a
            shared mutable ``[None]`` - create_filter_lists() normalises
            falsy values to ``[None]``.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples; empty when no filters could be created for the operator.

    Raises:
        Exception: if *opName* is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2999
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build and serialize a single test for operator *opName*.

    Generates the input tensors with the operator's tensor-values
    generator, builds the operator with its build function and, when the
    build succeeds, serializes the test together with its desc.json meta
    data (data generation and compliance information, when present).

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: name of the test.
        dtype_or_dtypeList: one dtype (replicated per operand, or per shape
            for CONCAT/CONCAT_SHAPE) or an explicit list of dtypes.
        error_name: ERROR_IF condition under test, or None.
        shapeList: list of operand shapes.
        argsDict: argument dictionary; must be a dict containing "dg_type".

    Raises:
        Exception: if *opName* is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts

    # CONCAT-style operators take one input per shape in shapeList
    is_variadic = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if is_variadic else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_variadic:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info, only for operators that declare a generator
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    # Build the random tensor operands and the test
    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: record any compliance meta data, then serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003090
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    Each template is shallow-copied into one concrete entry per kernel size
    (named e.g. "conv2d_3x1") with "filter" set to the kernel and
    "template" set to False. Every entry still flagged as a template is then
    removed. Kernel sizes are a small default set, or sizes built from
    TOSA_8K_LEVEL_MAX_KERNEL when self.args.level8k is set. A second call is
    a no-op because the templates are gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_concrete_op(base_name, kernel):
        # Shallow-copy the template and specialise it for this kernel size
        entry = op_list[base_name + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[base_name + "_" + "x".join(str(dim) for dim in kernel)] = entry

    for kernel in kernels_2d:
        for base_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_concrete_op(base_name, kernel)

    for kernel in kernels_3d:
        add_concrete_op("conv3d", kernel)

    # Collect the templates first: deleting keys from a dictionary while
    # iterating over it is not allowed
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3146
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must have an "operands" 2-tuple, a "build_fcn" 4-tuple, a
    "types" field and an "op" field. An entry without "rank" is given
    DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first offending op and what is missing.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        # Required fields; unpacking also checks the tuple arity
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        for required_key, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            try:
                _ = entry[required_key]
            except KeyError:
                raise Exception(message.format(op_name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003188
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#   if not specified, initOpListDefaults() fills in DEFAULT_RANK_RANGE,
#   i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, tensor shape generator,
#   tensor values generator, argument generator)
# 'types': list of datatypes to be tested
# Optional fields include 'qgen', 'error_if_validators',
# 'invalid_test_validators' and 'data_gen'.

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types - excludes INT4
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer then floating-point types - excludes INT4
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, INT8/16/32 and boolean types
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 plus the floating-point types (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003248
3249 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003250 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003251 "argmax": {
3252 "op": Op.ARGMAX,
3253 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003254 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 "build_fcn": (
3256 build_argmax,
3257 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003258 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003259 TosaArgGen.agAxis,
3260 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003261 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003262 "error_if_validators": (
3263 TosaErrorValidator.evAxisSmallerZero,
3264 TosaErrorValidator.evAxisLargerRank,
3265 TosaErrorValidator.evArgmaxOutputRankMismatch,
3266 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3267 TosaErrorValidator.evWrongRank,
3268 TosaErrorValidator.evWrongInputType,
3269 TosaErrorValidator.evWrongOutputType,
3270 TosaErrorValidator.evWrongInputList,
3271 TosaErrorValidator.evWrongOutputList,
3272 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003273 "data_gen": {
3274 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3275 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003276 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "avg_pool2d": {
3278 "op": Op.AVG_POOL2D,
3279 "operands": (1, 0),
3280 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281 "build_fcn": (
3282 build_pool2d,
3283 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003284 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 TosaArgGen.agPooling,
3286 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003288 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003289 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003290 "error_if_validators": (
3291 TosaErrorValidator.evKernelSmallerOne,
3292 TosaErrorValidator.evStrideSmallerOne,
3293 TosaErrorValidator.evPadSmallerZero,
3294 TosaErrorValidator.evWrongRank,
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 TosaErrorValidator.evInputZeroPointNotZero,
3300 TosaErrorValidator.evOutputZeroPointNotZero,
3301 TosaErrorValidator.evPadLargerEqualKernel,
3302 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003303 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003304 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003306 "data_gen": {
3307 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003310 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003311 "conv2d_TEMPLATE": {
3312 "op": Op.CONV2D,
3313 "operands": (1, 2),
3314 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 "build_fcn": (
3316 build_conv2d,
3317 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003318 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003319 TosaArgGen.agConv,
3320 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003321 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003322 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003323 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3324 "error_if_validators": (
3325 TosaErrorValidator.evWrongInputType,
3326 TosaErrorValidator.evWrongOutputType,
3327 TosaErrorValidator.evWrongInputList,
3328 TosaErrorValidator.evWrongOutputList,
3329 TosaErrorValidator.evInputZeroPointNotZero,
3330 TosaErrorValidator.evWeightZeroPointNotZero,
3331 TosaErrorValidator.evPadSmallerZero,
3332 TosaErrorValidator.evStrideSmallerOne,
3333 TosaErrorValidator.evDilationSmallerOne,
3334 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003335 TosaErrorValidator.evConvOutputShapeMismatch,
3336 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003337 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003338 "data_gen": {
3339 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3340 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003341 "template": True,
3342 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003343 # Templated operator. Filled in by createDynamicOpLists
3344 "conv3d_TEMPLATE": {
3345 "op": Op.CONV3D,
3346 "operands": (1, 2),
3347 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 "build_fcn": (
3349 build_conv3d,
3350 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003351 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352 TosaArgGen.agConv,
3353 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003354 "qgen": TosaQuantGen.qgConv,
3355 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003356 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3357 "error_if_validators": (
3358 TosaErrorValidator.evWrongInputType,
3359 TosaErrorValidator.evWrongOutputType,
3360 TosaErrorValidator.evWrongInputList,
3361 TosaErrorValidator.evWrongOutputList,
3362 TosaErrorValidator.evInputZeroPointNotZero,
3363 TosaErrorValidator.evWeightZeroPointNotZero,
3364 TosaErrorValidator.evPadSmallerZero,
3365 TosaErrorValidator.evStrideSmallerOne,
3366 TosaErrorValidator.evDilationSmallerOne,
3367 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003368 TosaErrorValidator.evConvOutputShapeMismatch,
3369 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003370 ),
evacha0147ab1762024-01-29 13:23:23 +00003371 "data_gen": {
3372 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3373 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003374 "template": True,
3375 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003376 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003377 "depthwise_conv2d_TEMPLATE": {
3378 "op": Op.DEPTHWISE_CONV2D,
3379 "operands": (1, 2),
3380 "filter": [1, 1],
3381 "rank": (4, 4),
3382 "build_fcn": (
3383 build_depthwise_conv2d,
3384 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003385 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003386 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003387 ),
3388 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003389 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003390 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3391 "error_if_validators": (
3392 TosaErrorValidator.evWrongInputType,
3393 TosaErrorValidator.evWrongOutputType,
3394 TosaErrorValidator.evWrongInputList,
3395 TosaErrorValidator.evWrongOutputList,
3396 TosaErrorValidator.evInputZeroPointNotZero,
3397 TosaErrorValidator.evWeightZeroPointNotZero,
3398 TosaErrorValidator.evPadSmallerZero,
3399 TosaErrorValidator.evStrideSmallerOne,
3400 TosaErrorValidator.evDilationSmallerOne,
3401 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003402 TosaErrorValidator.evConvOutputShapeMismatch,
3403 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003404 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003405 "data_gen": {
3406 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3407 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003408 "template": True,
3409 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003410 "fully_connected": {
3411 "op": Op.FULLY_CONNECTED,
3412 "operands": (1, 2),
3413 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 "build_fcn": (
3415 build_fully_connected,
3416 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003417 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003418 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003419 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003420 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003421 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003422 "error_if_validators": (
3423 TosaErrorValidator.evInputZeroPointNotZero,
3424 TosaErrorValidator.evWeightZeroPointNotZero,
3425 TosaErrorValidator.evWrongRank,
3426 TosaErrorValidator.evWrongInputType,
3427 TosaErrorValidator.evWrongOutputType,
3428 TosaErrorValidator.evWrongInputList,
3429 TosaErrorValidator.evWrongOutputList,
3430 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003431 "data_gen": {
3432 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "matmul": {
3436 "op": Op.MATMUL,
3437 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003438 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_matmul,
3441 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003442 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003443 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003446 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 "error_if_validators": (
3448 TosaErrorValidator.evInputZeroPointNotZero,
3449 TosaErrorValidator.evWrongRank,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003455 "data_gen": {
3456 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003457 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 "max_pool2d": {
3460 "op": Op.MAX_POOL2D,
3461 "operands": (1, 0),
3462 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003463 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003464 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003465 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003466 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 TosaArgGen.agPooling,
3468 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003469 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003470 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 "error_if_validators": (
3472 TosaErrorValidator.evKernelSmallerOne,
3473 TosaErrorValidator.evStrideSmallerOne,
3474 TosaErrorValidator.evPadSmallerZero,
3475 TosaErrorValidator.evWrongRank,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evPadLargerEqualKernel,
3481 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003482 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003484 "data_gen": {
3485 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3486 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003487 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003488 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003489 "transpose_conv2d_TEMPLATE": {
3490 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003491 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003492 "rank": (4, 4),
3493 "build_fcn": (
3494 build_transpose_conv2d,
3495 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003496 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003497 TosaArgGen.agTransposeConv2D,
3498 ),
3499 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003500 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003501 "invalid_test_validators": (
3502 TosaInvalidValidator.ivHeightWidthInvalid,
3503 TosaInvalidValidator.ivNonPositiveOutputShape,
3504 ),
3505 "error_if_validators": (
3506 TosaErrorValidator.evWrongInputType,
3507 TosaErrorValidator.evWrongOutputType,
3508 TosaErrorValidator.evWrongInputList,
3509 TosaErrorValidator.evWrongOutputList,
3510 TosaErrorValidator.evInputZeroPointNotZero,
3511 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003512 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003513 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003514 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003515 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003516 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003517 "data_gen": {
3518 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3519 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 "template": True,
3521 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003522 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003523 "clamp": {
3524 "op": Op.CLAMP,
3525 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003526 "build_fcn": (
3527 build_clamp,
3528 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003529 TosaTensorValuesGen.tvgLazyGenDefault,
3530 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003532 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003533 "error_if_validators": (
3534 TosaErrorValidator.evMaxSmallerMin,
3535 TosaErrorValidator.evWrongInputType,
3536 TosaErrorValidator.evWrongOutputType,
3537 TosaErrorValidator.evWrongInputList,
3538 TosaErrorValidator.evWrongOutputList,
3539 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003540 "data_gen": {
3541 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3542 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003543 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003544 "sigmoid": {
3545 "op": Op.SIGMOID,
3546 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003548 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003549 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003550 TosaTensorValuesGen.tvgLazyGenDefault,
3551 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003553 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 "error_if_validators": (
3555 TosaErrorValidator.evWrongInputType,
3556 TosaErrorValidator.evWrongOutputType,
3557 TosaErrorValidator.evWrongInputList,
3558 TosaErrorValidator.evWrongOutputList,
3559 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003560 "data_gen": {
3561 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3562 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003563 },
3564 "tanh": {
3565 "op": Op.TANH,
3566 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003568 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003570 TosaTensorValuesGen.tvgLazyGenDefault,
3571 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003572 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003573 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 "error_if_validators": (
3575 TosaErrorValidator.evWrongInputType,
3576 TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList,
3579 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003580 "data_gen": {
3581 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3582 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003583 "compliance": {
3584 "abs_error_lower_bound": 0.5,
3585 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003586 },
Won Jeon78155c62023-06-10 00:20:04 +00003587 "erf": {
3588 "op": Op.ERF,
3589 "operands": (1, 0),
3590 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003591 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003592 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003593 TosaTensorValuesGen.tvgLazyGenDefault,
3594 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003595 ),
3596 "types": TYPE_FP,
3597 "error_if_validators": (
3598 TosaErrorValidator.evWrongInputType,
3599 TosaErrorValidator.evWrongOutputType,
3600 TosaErrorValidator.evWrongInputList,
3601 TosaErrorValidator.evWrongOutputList,
3602 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003603 "data_gen": {
3604 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3605 },
3606 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003607 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003608 # Elementwise Binary Operators
3609 "add": {
3610 "op": Op.ADD,
3611 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 "build_fcn": (
3613 build_binary_broadcast,
3614 TosaTensorGen.tgBroadcastFuzz,
3615 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003616 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003618 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evRankMismatch,
3621 TosaErrorValidator.evWrongInputType,
3622 TosaErrorValidator.evWrongOutputType,
3623 TosaErrorValidator.evWrongInputList,
3624 TosaErrorValidator.evWrongOutputList,
3625 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003626 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003627 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003628 "data_gen": {
3629 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3630 },
3631 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003632 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003633 "arithmetic_right_shift": {
3634 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3635 "operands": (2, 0),
3636 "build_fcn": (
3637 build_arithmetic_right_shift,
3638 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003640 TosaArgGen.agArithmeticRightShift,
3641 ),
3642 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003643 "error_if_validators": (
3644 TosaErrorValidator.evRankMismatch,
3645 TosaErrorValidator.evWrongInputType,
3646 TosaErrorValidator.evWrongOutputType,
3647 TosaErrorValidator.evWrongInputList,
3648 TosaErrorValidator.evWrongOutputList,
3649 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003650 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003651 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003653 "bitwise_and": {
3654 "op": Op.BITWISE_AND,
3655 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003656 "build_fcn": (
3657 build_binary_broadcast,
3658 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003659 TosaTensorValuesGen.tvgLazyGenDefault,
3660 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003662 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003663 "error_if_validators": (
3664 TosaErrorValidator.evRankMismatch,
3665 TosaErrorValidator.evWrongInputType,
3666 TosaErrorValidator.evWrongOutputType,
3667 TosaErrorValidator.evWrongInputList,
3668 TosaErrorValidator.evWrongOutputList,
3669 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003670 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003671 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003673 "bitwise_or": {
3674 "op": Op.BITWISE_OR,
3675 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003676 "build_fcn": (
3677 build_binary_broadcast,
3678 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003679 TosaTensorValuesGen.tvgLazyGenDefault,
3680 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003681 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003682 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003683 "error_if_validators": (
3684 TosaErrorValidator.evRankMismatch,
3685 TosaErrorValidator.evWrongInputType,
3686 TosaErrorValidator.evWrongOutputType,
3687 TosaErrorValidator.evWrongInputList,
3688 TosaErrorValidator.evWrongOutputList,
3689 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003690 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003691 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003693 "bitwise_xor": {
3694 "op": Op.BITWISE_XOR,
3695 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003696 "build_fcn": (
3697 build_binary_broadcast,
3698 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003699 TosaTensorValuesGen.tvgLazyGenDefault,
3700 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003701 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003702 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003703 "error_if_validators": (
3704 TosaErrorValidator.evRankMismatch,
3705 TosaErrorValidator.evWrongInputType,
3706 TosaErrorValidator.evWrongOutputType,
3707 TosaErrorValidator.evWrongInputList,
3708 TosaErrorValidator.evWrongOutputList,
3709 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003710 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003712 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003713 "intdiv": {
3714 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003715 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 "build_fcn": (
3717 build_binary_broadcast,
3718 TosaTensorGen.tgBroadcastFuzz,
3719 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003720 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003722 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003723 "error_if_validators": (
3724 TosaErrorValidator.evRankMismatch,
3725 TosaErrorValidator.evWrongInputType,
3726 TosaErrorValidator.evWrongOutputType,
3727 TosaErrorValidator.evWrongInputList,
3728 TosaErrorValidator.evWrongOutputList,
3729 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003730 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003732 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "logical_and": {
3734 "op": Op.LOGICAL_AND,
3735 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003736 "build_fcn": (
3737 build_binary_broadcast,
3738 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003739 TosaTensorValuesGen.tvgLazyGenDefault,
3740 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003742 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003743 "error_if_validators": (
3744 TosaErrorValidator.evRankMismatch,
3745 TosaErrorValidator.evWrongInputType,
3746 TosaErrorValidator.evWrongOutputType,
3747 TosaErrorValidator.evWrongInputList,
3748 TosaErrorValidator.evWrongOutputList,
3749 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003750 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "logical_left_shift": {
3754 "op": Op.LOGICAL_LEFT_SHIFT,
3755 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 "build_fcn": (
3757 build_binary_broadcast,
3758 TosaTensorGen.tgBroadcastFuzz,
3759 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003760 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003763 "error_if_validators": (
3764 TosaErrorValidator.evRankMismatch,
3765 TosaErrorValidator.evWrongInputType,
3766 TosaErrorValidator.evWrongOutputType,
3767 TosaErrorValidator.evWrongInputList,
3768 TosaErrorValidator.evWrongOutputList,
3769 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003770 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "logical_right_shift": {
3774 "op": Op.LOGICAL_RIGHT_SHIFT,
3775 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003776 "build_fcn": (
3777 build_binary_broadcast,
3778 TosaTensorGen.tgBroadcastFuzz,
3779 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003780 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003781 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003783 "error_if_validators": (
3784 TosaErrorValidator.evRankMismatch,
3785 TosaErrorValidator.evWrongInputType,
3786 TosaErrorValidator.evWrongOutputType,
3787 TosaErrorValidator.evWrongInputList,
3788 TosaErrorValidator.evWrongOutputList,
3789 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003790 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 "logical_or": {
3794 "op": Op.LOGICAL_OR,
3795 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003796 "build_fcn": (
3797 build_binary_broadcast,
3798 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003799 TosaTensorValuesGen.tvgLazyGenDefault,
3800 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003803 "error_if_validators": (
3804 TosaErrorValidator.evRankMismatch,
3805 TosaErrorValidator.evWrongInputType,
3806 TosaErrorValidator.evWrongOutputType,
3807 TosaErrorValidator.evWrongInputList,
3808 TosaErrorValidator.evWrongOutputList,
3809 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003810 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003812 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "logical_xor": {
3814 "op": Op.LOGICAL_XOR,
3815 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 "build_fcn": (
3817 build_binary_broadcast,
3818 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003819 TosaTensorValuesGen.tvgLazyGenDefault,
3820 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003821 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003822 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 "error_if_validators": (
3824 TosaErrorValidator.evRankMismatch,
3825 TosaErrorValidator.evWrongInputType,
3826 TosaErrorValidator.evWrongOutputType,
3827 TosaErrorValidator.evWrongInputList,
3828 TosaErrorValidator.evWrongOutputList,
3829 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003830 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "maximum": {
3834 "op": Op.MAXIMUM,
3835 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 "build_fcn": (
3837 build_binary_broadcast,
3838 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003839 TosaTensorValuesGen.tvgLazyGenDefault,
3840 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003842 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003843 "error_if_validators": (
3844 TosaErrorValidator.evRankMismatch,
3845 TosaErrorValidator.evWrongInputType,
3846 TosaErrorValidator.evWrongOutputType,
3847 TosaErrorValidator.evWrongInputList,
3848 TosaErrorValidator.evWrongOutputList,
3849 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003850 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003852 "data_gen": {
3853 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003856 "minimum": {
3857 "op": Op.MINIMUM,
3858 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 "build_fcn": (
3860 build_binary_broadcast,
3861 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003862 TosaTensorValuesGen.tvgLazyGenDefault,
3863 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003864 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003865 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003866 "error_if_validators": (
3867 TosaErrorValidator.evRankMismatch,
3868 TosaErrorValidator.evWrongInputType,
3869 TosaErrorValidator.evWrongOutputType,
3870 TosaErrorValidator.evWrongInputList,
3871 TosaErrorValidator.evWrongOutputList,
3872 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003873 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003875 "data_gen": {
3876 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 "mul": {
3880 "op": Op.MUL,
3881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003882 "build_fcn": (
3883 build_mul,
3884 TosaTensorGen.tgBroadcastFuzz,
3885 TosaTensorValuesGen.tvgMul,
3886 TosaArgGen.agMul,
3887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003888 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 "error_if_validators": (
3890 TosaErrorValidator.evWrongInputType,
3891 TosaErrorValidator.evWrongOutputType,
3892 TosaErrorValidator.evWrongInputList,
3893 TosaErrorValidator.evWrongOutputList,
3894 TosaErrorValidator.evRankMismatch,
3895 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003896 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003897 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003898 "data_gen": {
3899 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3900 },
3901 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003902 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 "pow": {
3904 "op": Op.POW,
3905 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003906 "build_fcn": (
3907 build_binary_broadcast,
3908 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003909 TosaTensorValuesGen.tvgPow,
3910 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003911 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003912 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003913 "error_if_validators": (
3914 TosaErrorValidator.evRankMismatch,
3915 TosaErrorValidator.evWrongInputType,
3916 TosaErrorValidator.evWrongOutputType,
3917 TosaErrorValidator.evWrongInputList,
3918 TosaErrorValidator.evWrongOutputList,
3919 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003920 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003921 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003922 "data_gen": {
3923 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3924 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "sub": {
3927 "op": Op.SUB,
3928 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 "build_fcn": (
3930 build_binary_broadcast,
3931 TosaTensorGen.tgBroadcastFuzz,
3932 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003933 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003936 "error_if_validators": (
3937 TosaErrorValidator.evRankMismatch,
3938 TosaErrorValidator.evWrongInputType,
3939 TosaErrorValidator.evWrongOutputType,
3940 TosaErrorValidator.evWrongInputList,
3941 TosaErrorValidator.evWrongOutputList,
3942 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003943 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003944 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003945 "data_gen": {
3946 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3947 },
3948 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003949 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003950 "table": {
3951 "op": Op.TABLE,
3952 # Use the automatic generation functions to create the input array
3953 # but create the table tensor in the build function, as it may be
3954 # a different type from the input
3955 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003956 "build_fcn": (
3957 build_table,
3958 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003959 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003960 TosaArgGen.agTable,
3961 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003962 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003963 "error_if_validators": (
3964 TosaErrorValidator.evWrongInputType,
3965 TosaErrorValidator.evWrongOutputType,
3966 TosaErrorValidator.evWrongInputList,
3967 TosaErrorValidator.evWrongOutputList,
3968 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 # Elementwise Unary operators
3971 "abs": {
3972 "op": Op.ABS,
3973 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003974 "build_fcn": (
3975 build_unary,
3976 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003977 TosaTensorValuesGen.tvgLazyGenDefault,
3978 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003980 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 "error_if_validators": (
3982 TosaErrorValidator.evWrongInputType,
3983 TosaErrorValidator.evWrongOutputType,
3984 TosaErrorValidator.evWrongInputList,
3985 TosaErrorValidator.evWrongOutputList,
3986 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003987 "data_gen": {
3988 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "bitwise_not": {
3992 "op": Op.BITWISE_NOT,
3993 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_unary,
3996 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003997 TosaTensorValuesGen.tvgLazyGenDefault,
3998 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evWrongInputType,
4003 TosaErrorValidator.evWrongOutputType,
4004 TosaErrorValidator.evWrongInputList,
4005 TosaErrorValidator.evWrongOutputList,
4006 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004007 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004008 "ceil": {
4009 "op": Op.CEIL,
4010 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004011 "build_fcn": (
4012 build_unary,
4013 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004014 TosaTensorValuesGen.tvgLazyGenDefault,
4015 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004016 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004018 "error_if_validators": (
4019 TosaErrorValidator.evWrongInputType,
4020 TosaErrorValidator.evWrongOutputType,
4021 TosaErrorValidator.evWrongInputList,
4022 TosaErrorValidator.evWrongOutputList,
4023 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004024 "data_gen": {
4025 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4026 },
4027 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004029 "clz": {
4030 "op": Op.CLZ,
4031 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004032 "build_fcn": (
4033 build_unary,
4034 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004035 TosaTensorValuesGen.tvgLazyGenDefault,
4036 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004038 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004039 "error_if_validators": (
4040 TosaErrorValidator.evWrongInputType,
4041 TosaErrorValidator.evWrongOutputType,
4042 TosaErrorValidator.evWrongInputList,
4043 TosaErrorValidator.evWrongOutputList,
4044 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004046 "exp": {
4047 "op": Op.EXP,
4048 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004049 "build_fcn": (
4050 build_unary,
4051 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004052 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004053 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004054 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004056 "error_if_validators": (
4057 TosaErrorValidator.evWrongInputType,
4058 TosaErrorValidator.evWrongOutputType,
4059 TosaErrorValidator.evWrongInputList,
4060 TosaErrorValidator.evWrongOutputList,
4061 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004062 "data_gen": {
4063 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4064 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 "floor": {
4067 "op": Op.FLOOR,
4068 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 "build_fcn": (
4070 build_unary,
4071 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004072 TosaTensorValuesGen.tvgLazyGenDefault,
4073 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004074 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004075 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004076 "error_if_validators": (
4077 TosaErrorValidator.evWrongInputType,
4078 TosaErrorValidator.evWrongOutputType,
4079 TosaErrorValidator.evWrongInputList,
4080 TosaErrorValidator.evWrongOutputList,
4081 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004082 "data_gen": {
4083 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4084 },
4085 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004087 "log": {
4088 "op": Op.LOG,
4089 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 "build_fcn": (
4091 build_unary,
4092 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004093 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004094 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004097 "error_if_validators": (
4098 TosaErrorValidator.evWrongInputType,
4099 TosaErrorValidator.evWrongOutputType,
4100 TosaErrorValidator.evWrongInputList,
4101 TosaErrorValidator.evWrongOutputList,
4102 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004103 "data_gen": {
4104 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4105 },
4106 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 "logical_not": {
4109 "op": Op.LOGICAL_NOT,
4110 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 "build_fcn": (
4112 build_unary,
4113 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004114 TosaTensorValuesGen.tvgLazyGenDefault,
4115 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 "error_if_validators": (
4119 TosaErrorValidator.evWrongInputType,
4120 TosaErrorValidator.evWrongOutputType,
4121 TosaErrorValidator.evWrongInputList,
4122 TosaErrorValidator.evWrongOutputList,
4123 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004124 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004125 "negate": {
4126 "op": Op.NEGATE,
4127 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 "build_fcn": (
4129 build_unary,
4130 TosaTensorGen.tgBasic,
4131 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004132 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004133 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004134 "qgen": TosaQuantGen.qgUnary,
4135 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004136 "error_if_validators": (
4137 TosaErrorValidator.evInputZeroPointNotZero,
4138 TosaErrorValidator.evOutputZeroPointNotZero,
4139 TosaErrorValidator.evWrongInputType,
4140 TosaErrorValidator.evWrongOutputType,
4141 TosaErrorValidator.evWrongInputList,
4142 TosaErrorValidator.evWrongOutputList,
4143 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004144 "data_gen": {
4145 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004147 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004148 "reciprocal": {
4149 "op": Op.RECIPROCAL,
4150 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004151 "build_fcn": (
4152 build_unary,
4153 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004154 TosaTensorValuesGen.tvgLazyGenDefault,
4155 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004156 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004157 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004158 "error_if_validators": (
4159 TosaErrorValidator.evWrongInputType,
4160 TosaErrorValidator.evWrongOutputType,
4161 TosaErrorValidator.evWrongInputList,
4162 TosaErrorValidator.evWrongOutputList,
4163 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004164 "data_gen": {
4165 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4166 },
4167 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004168 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004169 "rsqrt": {
4170 "op": Op.RSQRT,
4171 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004172 "build_fcn": (
4173 build_unary,
4174 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004175 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004176 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004177 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004179 "error_if_validators": (
4180 TosaErrorValidator.evWrongInputType,
4181 TosaErrorValidator.evWrongOutputType,
4182 TosaErrorValidator.evWrongInputList,
4183 TosaErrorValidator.evWrongOutputList,
4184 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004185 "data_gen": {
4186 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4187 },
4188 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 # Elementwise Ternary operators
4191 "select": {
4192 "op": Op.SELECT,
4193 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_select,
4196 TosaTensorGen.tgBroadcastFuzz,
4197 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004198 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evRankMismatch,
4203 TosaErrorValidator.evWrongInputType,
4204 TosaErrorValidator.evWrongOutputType,
4205 TosaErrorValidator.evWrongInputList,
4206 TosaErrorValidator.evWrongOutputList,
4207 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004208 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004209 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004210 "data_gen": {
4211 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4212 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004213 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004214 # Comparison operators
4215 "equal": {
4216 "op": Op.EQUAL,
4217 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004218 "build_fcn": (
4219 build_comparison,
4220 TosaTensorGen.tgBroadcastFuzz,
4221 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004222 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004223 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004224 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004225 "error_if_validators": (
4226 TosaErrorValidator.evRankMismatch,
4227 TosaErrorValidator.evWrongInputType,
4228 TosaErrorValidator.evWrongOutputType,
4229 TosaErrorValidator.evWrongInputList,
4230 TosaErrorValidator.evWrongOutputList,
4231 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004232 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004233 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004234 "data_gen": {
4235 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4236 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004237 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004238 "greater_equal": {
4239 "op": Op.GREATER_EQUAL,
4240 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 "build_fcn": (
4242 build_comparison,
4243 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004244 TosaTensorValuesGen.tvgLazyGenDefault,
4245 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004246 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evRankMismatch,
4250 TosaErrorValidator.evWrongInputType,
4251 TosaErrorValidator.evWrongOutputType,
4252 TosaErrorValidator.evWrongInputList,
4253 TosaErrorValidator.evWrongOutputList,
4254 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004255 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004256 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004257 "data_gen": {
4258 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004260 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004261 "greater": {
4262 "op": Op.GREATER,
4263 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004264 "build_fcn": (
4265 build_comparison,
4266 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004267 TosaTensorValuesGen.tvgLazyGenDefault,
4268 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004270 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004271 "error_if_validators": (
4272 TosaErrorValidator.evRankMismatch,
4273 TosaErrorValidator.evWrongInputType,
4274 TosaErrorValidator.evWrongOutputType,
4275 TosaErrorValidator.evWrongInputList,
4276 TosaErrorValidator.evWrongOutputList,
4277 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004278 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004279 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004280 "data_gen": {
4281 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004283 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004284 # Reduction operators
4285 "reduce_all": {
4286 "op": Op.REDUCE_ALL,
4287 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004288 "build_fcn": (
4289 build_reduce,
4290 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004291 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004292 TosaArgGen.agAxis,
4293 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004294 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004295 "error_if_validators": (
4296 TosaErrorValidator.evAxisLargerRank,
4297 TosaErrorValidator.evAxisSmallerZero,
4298 TosaErrorValidator.evShapeOfAxisNotOne,
4299 TosaErrorValidator.evWrongInputType,
4300 TosaErrorValidator.evWrongOutputType,
4301 TosaErrorValidator.evWrongRank,
4302 TosaErrorValidator.evWrongInputList,
4303 TosaErrorValidator.evWrongOutputList,
4304 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004306 "reduce_any": {
4307 "op": Op.REDUCE_ANY,
4308 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004309 "build_fcn": (
4310 build_reduce,
4311 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004312 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004313 TosaArgGen.agAxis,
4314 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004315 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004316 "error_if_validators": (
4317 TosaErrorValidator.evAxisLargerRank,
4318 TosaErrorValidator.evAxisSmallerZero,
4319 TosaErrorValidator.evShapeOfAxisNotOne,
4320 TosaErrorValidator.evWrongInputType,
4321 TosaErrorValidator.evWrongOutputType,
4322 TosaErrorValidator.evWrongRank,
4323 TosaErrorValidator.evWrongInputList,
4324 TosaErrorValidator.evWrongOutputList,
4325 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004326 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004327 "reduce_max": {
4328 "op": Op.REDUCE_MAX,
4329 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004330 "build_fcn": (
4331 build_reduce,
4332 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004333 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004334 TosaArgGen.agAxis,
4335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004336 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004337 "error_if_validators": (
4338 TosaErrorValidator.evAxisLargerRank,
4339 TosaErrorValidator.evAxisSmallerZero,
4340 TosaErrorValidator.evShapeOfAxisNotOne,
4341 TosaErrorValidator.evWrongInputType,
4342 TosaErrorValidator.evWrongOutputType,
4343 TosaErrorValidator.evWrongRank,
4344 TosaErrorValidator.evWrongInputList,
4345 TosaErrorValidator.evWrongOutputList,
4346 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004347 "data_gen": {
4348 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4349 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004351 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004352 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004354 "build_fcn": (
4355 build_reduce,
4356 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004357 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004358 TosaArgGen.agAxis,
4359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004361 "error_if_validators": (
4362 TosaErrorValidator.evAxisLargerRank,
4363 TosaErrorValidator.evAxisSmallerZero,
4364 TosaErrorValidator.evShapeOfAxisNotOne,
4365 TosaErrorValidator.evWrongInputType,
4366 TosaErrorValidator.evWrongOutputType,
4367 TosaErrorValidator.evWrongRank,
4368 TosaErrorValidator.evWrongInputList,
4369 TosaErrorValidator.evWrongOutputList,
4370 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004371 "data_gen": {
4372 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004374 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004375 "reduce_product": {
4376 "op": Op.REDUCE_PRODUCT,
4377 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004378 "build_fcn": (
4379 build_reduce,
4380 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004381 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004382 TosaArgGen.agAxis,
4383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004384 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004385 "error_if_validators": (
4386 TosaErrorValidator.evAxisLargerRank,
4387 TosaErrorValidator.evAxisSmallerZero,
4388 TosaErrorValidator.evShapeOfAxisNotOne,
4389 TosaErrorValidator.evWrongInputType,
4390 TosaErrorValidator.evWrongOutputType,
4391 TosaErrorValidator.evWrongRank,
4392 TosaErrorValidator.evWrongInputList,
4393 TosaErrorValidator.evWrongOutputList,
4394 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004395 "data_gen": {
4396 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4397 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004398 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004399 "reduce_sum": {
4400 "op": Op.REDUCE_SUM,
4401 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004402 "build_fcn": (
4403 build_reduce,
4404 TosaTensorGen.tgBasic,
4405 TosaTensorValuesGen.tvgReduceSum,
4406 TosaArgGen.agAxis,
4407 ),
James Ward24dbc422022-10-19 12:20:31 +01004408 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004409 "error_if_validators": (
4410 TosaErrorValidator.evAxisLargerRank,
4411 TosaErrorValidator.evAxisSmallerZero,
4412 TosaErrorValidator.evShapeOfAxisNotOne,
4413 TosaErrorValidator.evWrongInputType,
4414 TosaErrorValidator.evWrongOutputType,
4415 TosaErrorValidator.evWrongRank,
4416 TosaErrorValidator.evWrongInputList,
4417 TosaErrorValidator.evWrongOutputList,
4418 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004419 "data_gen": {
4420 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4421 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004423 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004424 "concat": {
4425 "op": Op.CONCAT,
4426 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004427 "build_fcn": (
4428 build_concat,
4429 TosaTensorGen.tgConcat,
4430 TosaTensorValuesGen.tvgConcat,
4431 TosaArgGen.agAxis,
4432 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004433 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004434 "error_if_validators": (
4435 TosaErrorValidator.evAxisLargerRank,
4436 TosaErrorValidator.evAxisSmallerZero,
4437 TosaErrorValidator.evConcatInputRankMismatch,
4438 TosaErrorValidator.evConcatShapeSumMismatch,
4439 TosaErrorValidator.evConcatInputDimMismatch,
4440 TosaErrorValidator.evWrongInputType,
4441 TosaErrorValidator.evWrongOutputType,
4442 TosaErrorValidator.evWrongOutputList,
4443 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004444 "data_gen": {
4445 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4446 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004447 },
4448 "pad": {
4449 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004450 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004451 "build_fcn": (
4452 build_pad,
4453 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004454 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004455 TosaArgGen.agPad,
4456 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004457 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004458 "error_if_validators": (
4459 TosaErrorValidator.evWrongInputType,
4460 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004461 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004462 TosaErrorValidator.evWrongOutputType,
4463 TosaErrorValidator.evWrongInputList,
4464 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004465 TosaErrorValidator.evRankMismatch,
4466 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004467 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004468 "data_gen": {
4469 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4470 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004471 },
Won Jeona21b2e82023-08-10 10:33:01 +00004472 "dim": {
4473 "op": Op.DIM,
4474 "operands": (1, 0),
4475 "build_fcn": (
4476 build_dim,
4477 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004478 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004479 TosaArgGen.agAxis,
4480 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004481 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004482 "error_if_validators": (
4483 TosaErrorValidator.evAxisLargerRank,
4484 TosaErrorValidator.evAxisSmallerZero,
4485 TosaErrorValidator.evWrongInputType,
4486 TosaErrorValidator.evWrongInputList,
4487 TosaErrorValidator.evWrongOutputList,
4488 TosaErrorValidator.evWrongRank,
4489 ),
4490 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004491 "reshape": {
4492 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004493 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004494 "build_fcn": (
4495 build_reshape,
4496 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004497 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004498 TosaArgGen.agReshape,
4499 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004500 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004501 "error_if_validators": (
4502 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4503 TosaErrorValidator.evWrongInputType,
4504 TosaErrorValidator.evWrongOutputType,
4505 TosaErrorValidator.evWrongInputList,
4506 TosaErrorValidator.evWrongOutputList,
4507 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004508 "data_gen": {
4509 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4510 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004511 },
4512 "reverse": {
4513 "op": Op.REVERSE,
4514 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004515 "build_fcn": (
4516 build_reverse,
4517 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004518 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004519 TosaArgGen.agAxis,
4520 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004521 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004522 "error_if_validators": (
4523 TosaErrorValidator.evAxisSmallerZero,
4524 TosaErrorValidator.evAxisLargerRank,
4525 TosaErrorValidator.evWrongInputType,
4526 TosaErrorValidator.evWrongOutputType,
4527 TosaErrorValidator.evWrongInputList,
4528 TosaErrorValidator.evWrongOutputList,
4529 ),
evacha0198477222024-01-26 12:25:32 +00004530 "data_gen": {
4531 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4532 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004533 },
4534 "slice": {
4535 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004536 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004537 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004538 "build_fcn": (
4539 build_slice,
4540 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004541 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004542 TosaArgGen.agSlice,
4543 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004544 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004545 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004546 # TODO Turn off these error categories for now as the reference
4547 # model cannot allocate memory space for empty tensor. We probably
4548 # can report an accurate error message at the right place during
4549 # execution.
4550 # TosaErrorValidator.evStartSmallerZero,
4551 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004552 TosaErrorValidator.evStartSizeOutsideBounds,
4553 TosaErrorValidator.evSizeOutputShapeMismatch,
4554 TosaErrorValidator.evInputSizeStartLengthMismatch,
4555 TosaErrorValidator.evWrongRank,
4556 TosaErrorValidator.evWrongInputType,
4557 TosaErrorValidator.evWrongOutputType,
4558 TosaErrorValidator.evWrongInputList,
4559 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004560 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004561 ),
evacha017f7d4252024-01-24 12:08:09 +00004562 "data_gen": {
4563 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4564 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004565 },
4566 "tile": {
4567 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004568 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004569 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004570 "build_fcn": (
4571 build_tile,
4572 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004573 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004574 TosaArgGen.agTile,
4575 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004576 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004577 "error_if_validators": (
4578 TosaErrorValidator.evWrongInputType,
4579 TosaErrorValidator.evWrongOutputType,
4580 TosaErrorValidator.evWrongInputList,
4581 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004582 TosaErrorValidator.evRankMismatch,
4583 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004584 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004585 "data_gen": {
4586 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4587 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004588 },
4589 "transpose": {
4590 "op": Op.TRANSPOSE,
4591 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004592 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004593 "build_fcn": (
4594 build_transpose,
4595 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004596 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004597 TosaArgGen.agTranspose,
4598 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004599 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004600 "error_if_validators": (
4601 TosaErrorValidator.evIndexOutsideBounds,
4602 TosaErrorValidator.evIndexUsedTwice,
4603 TosaErrorValidator.evWrongInputType,
4604 TosaErrorValidator.evWrongOutputType,
4605 TosaErrorValidator.evWrongInputList,
4606 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004607 TosaErrorValidator.evWrongRank,
4608 TosaErrorValidator.evRankMismatch,
4609 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004610 ),
evacha0198477222024-01-26 12:25:32 +00004611 "data_gen": {
4612 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4613 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004615 # Data nodes
4616 "const": {
4617 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004618 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004619 "build_fcn": (
4620 build_const,
4621 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004622 TosaTensorValuesGen.tvgLazyGenDefault,
4623 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004624 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004625 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004626 "data_gen": {
4627 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4628 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004629 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004630 "identity": {
4631 "op": Op.IDENTITY,
4632 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004633 "build_fcn": (
4634 build_unary,
4635 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004636 TosaTensorValuesGen.tvgLazyGenDefault,
4637 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004638 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004639 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004640 "data_gen": {
4641 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4642 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004643 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004644 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004645 "gather": {
4646 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004647 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004648 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004649 "build_fcn": (
4650 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004651 TosaTensorGen.tgGather,
4652 TosaTensorValuesGen.tvgGather,
4653 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004654 ),
James Ward24dbc422022-10-19 12:20:31 +01004655 "types": (
4656 DType.INT8,
4657 DType.INT16,
4658 DType.INT32,
4659 DType.FP16,
4660 DType.BF16,
4661 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004662 DType.FP8E4M3,
4663 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004664 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004665 "error_if_validators": (
4666 TosaErrorValidator.evWrongInputType,
4667 TosaErrorValidator.evWrongOutputType,
4668 TosaErrorValidator.evWrongInputList,
4669 TosaErrorValidator.evWrongOutputList,
4670 TosaErrorValidator.evWrongRank,
4671 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004672 "data_gen": {
4673 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4674 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004675 },
4676 "scatter": {
4677 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004678 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004679 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004680 "build_fcn": (
4681 build_scatter,
4682 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004683 TosaTensorValuesGen.tvgScatter,
4684 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004685 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004686 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004687 "error_if_validators": (
4688 TosaErrorValidator.evWrongInputType,
4689 TosaErrorValidator.evWrongOutputType,
4690 TosaErrorValidator.evWrongInputList,
4691 TosaErrorValidator.evWrongOutputList,
4692 TosaErrorValidator.evWrongRank,
4693 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004694 "data_gen": {
4695 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4696 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004697 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004698 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004699 "resize": {
4700 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004701 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004702 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004703 "build_fcn": (
4704 build_resize,
4705 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004706 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004707 TosaArgGen.agResize,
4708 ),
James Ward24dbc422022-10-19 12:20:31 +01004709 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004710 "invalid_test_validators": (
4711 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004712 ),
4713 "error_if_validators": (
4714 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004715 TosaErrorValidator.evScaleSmallerEqualZero,
4716 TosaErrorValidator.evScaleNLargerMax,
4717 TosaErrorValidator.evScaleDLargerMax,
4718 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004719 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004720 TosaErrorValidator.evBorderSmallerMin,
4721 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004722 TosaErrorValidator.evWrongInputType,
4723 TosaErrorValidator.evWrongOutputType,
4724 TosaErrorValidator.evWrongRank,
4725 TosaErrorValidator.evWrongInputList,
4726 TosaErrorValidator.evWrongOutputList,
4727 TosaErrorValidator.evBatchMismatch,
4728 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004729 TosaErrorValidator.evResizeOutputShapeMismatch,
4730 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004731 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004732 "data_gen": {
4733 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4734 },
4735 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004736 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004737 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004738 "cast": {
4739 "op": Op.CAST,
4740 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004741 "build_fcn": (
4742 build_cast,
4743 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004744 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004745 TosaArgGen.agCast,
4746 ),
James Ward8b390432022-08-12 20:48:56 +01004747 "types": (
4748 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004749 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004750 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004751 DType.INT8,
4752 DType.INT16,
4753 DType.INT32,
4754 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004755 DType.FP8E4M3,
4756 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004757 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004758 "error_if_validators": (
4759 TosaErrorValidator.evWrongInputType,
4760 TosaErrorValidator.evWrongOutputType,
4761 TosaErrorValidator.evWrongInputList,
4762 TosaErrorValidator.evWrongOutputList,
4763 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004764 "data_gen": {
4765 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4766 },
4767 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004768 },
4769 "rescale": {
4770 "op": Op.RESCALE,
4771 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004772 "build_fcn": (
4773 build_rescale,
4774 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004775 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004776 TosaArgGen.agRescale,
4777 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004778 "types": [
4779 DType.UINT8,
4780 DType.INT8,
4781 DType.INT16,
4782 DType.INT32,
4783 DType.INT48,
4784 DType.UINT16,
4785 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004786 "error_if_validators": (
4787 TosaErrorValidator.evInputZeroPointNotZero,
4788 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004789 TosaErrorValidator.evU16InputZeroPointNotValid,
4790 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004791 TosaErrorValidator.evScaleTrue,
4792 TosaErrorValidator.evScaleNotTrue,
4793 TosaErrorValidator.evWrongInputType,
4794 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004795 TosaErrorValidator.evWrongInputList,
4796 TosaErrorValidator.evWrongOutputList,
4797 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004798 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004799 # Custom
4800 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004801 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004802 # Two variants of cond_if, one that generates one of two constant tensors (no
4803 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4804 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004805 "cond_if_const": {
4806 "op": Op.COND_IF,
4807 "operands": (0, 2),
4808 "build_fcn": (
4809 build_cond_if_const,
4810 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004811 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004812 TosaArgGen.agCondIf,
4813 ),
4814 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004815 "error_if_validators": (
4816 TosaErrorValidator.evOutputListThenGraphMismatch,
4817 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004818 TosaErrorValidator.evCondIfCondNotMatchingBool,
4819 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004820 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004821 },
4822 "cond_if_binary": {
4823 "op": Op.COND_IF,
4824 "operands": (2, 0),
4825 "build_fcn": (
4826 build_cond_if_binary,
4827 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004828 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004829 TosaArgGen.agCondIf,
4830 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004831 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004832 "error_if_validators": (
4833 TosaErrorValidator.evInputListThenGraphMismatch,
4834 TosaErrorValidator.evInputListElseGraphMismatch,
4835 TosaErrorValidator.evOutputListThenGraphMismatch,
4836 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004837 TosaErrorValidator.evCondIfCondNotMatchingBool,
4838 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004839 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004840 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004841 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004842 "while_loop": {
4843 "op": Op.WHILE_LOOP,
4844 "operands": (0, 1),
4845 "build_fcn": (
4846 build_while_loop,
4847 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004848 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004849 TosaArgGen.agWhileLoop,
4850 ),
4851 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 "error_if_validators": (
4853 TosaErrorValidator.evInputListOutputListMismatch,
4854 TosaErrorValidator.evInputListCondGraphMismatch,
4855 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4856 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4857 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004858 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004859 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004860 },
Luke Hutton57287132023-02-06 14:54:18 +00004861 "fft2d": {
4862 "op": Op.FFT2D,
4863 "operands": (2, 0),
4864 "rank": (3, 3),
4865 "build_fcn": (
4866 build_fft2d,
4867 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004868 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004869 TosaArgGen.agFFT2d,
4870 ),
4871 "types": [DType.FP32],
4872 "error_if_validators": (
4873 TosaErrorValidator.evWrongInputType,
4874 TosaErrorValidator.evWrongOutputType,
4875 TosaErrorValidator.evWrongInputList,
4876 TosaErrorValidator.evWrongOutputList,
4877 TosaErrorValidator.evWrongRank,
4878 TosaErrorValidator.evBatchMismatch,
4879 TosaErrorValidator.evKernelNotPowerOfTwo,
4880 TosaErrorValidator.evFFTInputShapeMismatch,
4881 TosaErrorValidator.evFFTOutputShapeMismatch,
4882 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004883 "data_gen": {
4884 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4885 },
Luke Hutton57287132023-02-06 14:54:18 +00004886 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004887 "rfft2d": {
4888 "op": Op.RFFT2D,
4889 "operands": (1, 0),
4890 "rank": (3, 3),
4891 "build_fcn": (
4892 build_rfft2d,
4893 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004894 TosaTensorValuesGen.tvgLazyGenDefault,
4895 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004896 ),
4897 "types": [DType.FP32],
4898 "error_if_validators": (
4899 TosaErrorValidator.evWrongInputType,
4900 TosaErrorValidator.evWrongOutputType,
4901 TosaErrorValidator.evWrongInputList,
4902 TosaErrorValidator.evWrongOutputList,
4903 TosaErrorValidator.evWrongRank,
4904 TosaErrorValidator.evBatchMismatch,
4905 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004906 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004907 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004908 "data_gen": {
4909 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4910 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004911 },
Won Jeon74342e52024-01-09 00:34:40 +00004912 # Shape
4913 "add_shape": {
4914 "op": Op.ADD_SHAPE,
4915 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004916 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004917 "build_fcn": (
4918 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004919 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004920 TosaTensorValuesGen.tvgAddSub,
4921 TosaArgGen.agNone,
4922 ),
4923 "types": [DType.SHAPE],
4924 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4925 },
4926 "sub_shape": {
4927 "op": Op.SUB_SHAPE,
4928 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004929 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004930 "build_fcn": (
4931 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004932 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004933 TosaTensorValuesGen.tvgAddSub,
4934 TosaArgGen.agNone,
4935 ),
4936 "types": [DType.SHAPE],
4937 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4938 },
4939 "mul_shape": {
4940 "op": Op.MUL_SHAPE,
4941 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004942 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004943 "build_fcn": (
4944 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004945 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004946 TosaTensorValuesGen.tvgMul,
4947 TosaArgGen.agNone,
4948 ),
4949 "types": [DType.SHAPE],
4950 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4951 },
4952 "div_shape": {
4953 "op": Op.DIV_SHAPE,
4954 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004955 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004956 "build_fcn": (
4957 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004958 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004959 TosaTensorValuesGen.tvgIntDiv,
4960 TosaArgGen.agNone,
4961 ),
4962 "types": [DType.SHAPE],
4963 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4964 },
4965 "concat_shape": {
4966 "op": Op.CONCAT_SHAPE,
4967 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004968 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004969 "build_fcn": (
4970 build_concat,
4971 TosaTensorGen.tgConcat,
4972 TosaTensorValuesGen.tvgConcat,
4973 TosaArgGen.agNone,
4974 ),
4975 "types": [DType.SHAPE],
4976 "error_if_validators": (),
4977 },
4978 "const_shape": {
4979 "op": Op.CONST_SHAPE,
4980 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004981 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004982 "build_fcn": (
4983 build_const,
4984 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004985 TosaTensorValuesGen.tvgLazyGenDefault,
4986 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00004987 ),
4988 "types": [DType.SHAPE],
4989 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004990 }
4991
Kevin Cheng550ccc52021-03-03 11:21:43 -08004992
Eric Kunzee5e26762020-10-13 16:11:07 -07004993class OutputShaper:
4994 # Methods in this class compute the expected output shape and datatype
4995 # for common classes of operations
4996 def __init__(self):
4997 pass
4998
4999 # These methods return arguments that can be used for
5000 # creating a new output tensor
5001 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise binary op that broadcasts.

    For a valid test (``error_name is None``) every size-1 dimension of ``a``
    takes the matching size from ``b``. For ERROR_IF tests ``a``'s shape is
    used as-is before the fuzzing described below.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator used for the fuzz index and the wrong-dtype pick.
        a: first input tensor (provides ``shape`` and ``dtype``).
        b: second input tensor; must share ``a``'s dtype, and its rank unless
            RankMismatch is being tested.
        error_name: ErrorIf case being generated, or None for a valid test.
            DimensionMismatch grows one random output dimension by 1.
            WrongOutputType picks a random dtype, from a fixed list, that
            differs from ``a.dtype``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    may_broadcast = error_name is None
    out_shape = [
        b.shape[pos] if dim == 1 and may_broadcast else dim
        for pos, dim in enumerate(a.shape)
    ]

    # The fuzz index is drawn whatever error_name is, so this RNG draw always
    # happens; only DimensionMismatch actually applies it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    )
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005034
5035 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    ``a`` and ``b`` must have the same rank, the same dtype and identical
    sizes in every dimension (checked with ``assert``). The output is a new
    list copying ``a``'s dimensions, with ``a``'s dtype.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim
        out_shape.append(a_dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005046
5047 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an element-wise unary op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype``, except for
    a WrongOutputType ERROR_IF test, where a dtype different from ``a.dtype``
    is picked at random from a fixed list.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005065
5066 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for SELECT(cond, a, b).

    The output shape follows ``cond``. For a valid test, each size-1 dimension
    of ``cond`` becomes the largest of the ``cond``/``a``/``b`` sizes on that
    axis (broadcast). DimensionMismatch grows one random dimension by 1.
    WrongOutputType picks a random dtype, from a fixed list, that differs
    from ``a.dtype``.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for pos, cond_dim in enumerate(cond.shape):
        if error_name is None and cond_dim == 1:
            out_shape.append(max(cond_dim, a.shape[pos], b.shape[pos]))
        else:
            out_shape.append(cond_dim)

    # Drawn regardless of error_name; only DimensionMismatch applies it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005099
5100 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an element-wise comparison op.

    Valid tests produce a BOOL tensor. Each size-1 dimension of ``a`` takes
    ``b``'s size on that axis when ``b`` has the axis. Unlike
    ``binaryBroadcastOp``, this broadcast is applied even when ``error_name``
    is set. DimensionMismatch grows one random dimension by 1.
    WrongOutputType picks a random non-BOOL dtype.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[pos] if dim == 1 and pos < b_rank else dim
        for pos, dim in enumerate(a.shape)
    ]

    # Drawn regardless of error_name; only DimensionMismatch applies it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, the dtype used for valid tests.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005133
5134 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction along ``axis``.

    The rank is kept; normally the reduced axis becomes size 1. Behaviour
    depends on ``error_name``:

    - AxisSmallerZero / AxisLargerRank: the shape is left untouched.
    - ShapeOfAxisNotOne: if that axis is already 1, it is set to a random
      value in [2, 10).
    - WrongOutputType: a random dtype, from a fixed list, different from
      ``a.dtype`` is used.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005162
5163 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the index output tensor for ARGMAX along ``axis``.

    Valid tests remove ``axis`` from ``a``'s shape and use INT32. ERROR_IF
    variants:

    - AxisSmallerZero / AxisLargerRank: the axis is not removed.
    - ArgmaxOutputRankMismatch: randomly drop the leading dim (only if more
      than one dim remains) or append a trailing size-1 dim.
    - ArgmaxOutputShapeMismatch: grow every remaining dim by a random 1..9.
    - WrongOutputType: a random non-INT32 dtype from a fixed list.

    Returns:
        The result of ``ser.addOutput(shape, dtype)``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005198
5199 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC with
    ``filter.shape[0]`` output channels. Spatial output sizes:

        OH = (IH - 1 + pad[0] + pad[1] - (KH - 1) * dilation[0]) // stride[0] + 1
        OW = (IW - 1 + pad[2] + pad[3] - (KW - 1) * dilation[1]) // stride[1] + 1

    where IH/IW = ``ifm.shape[1:3]`` and KH/KW = ``filter.shape[1:3]``.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz shape/dtype for ERROR_IF tests.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype; used as the output dtype for valid
            tests.
        strides, padding, dilations: convolution attributes, indexed as above.
        error_name: ErrorIf case being generated, or None for a valid test.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # This must be a single if/elif/else chain. With a separate `if` for
        # FP8, its `else` would overwrite the FP16 exclusions, and an FP16
        # input could then be given a legal FP16/FP32 "wrong" output.
        if ifm.dtype == DType.FP16:
            # FP16 input may legitimately produce FP16 or FP32 output.
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005252
5253 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC with
    ``filter.shape[0]`` output channels. For spatial axis k (0=D, 1=H, 2=W):

        out[k] = (ifm.shape[k+1] - 1 + padding[2k] + padding[2k+1]
                  - (filter.shape[k+1] - 1) * dilations[k]) // strides[k] + 1

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz shape/dtype for ERROR_IF tests.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype; used as the output dtype for valid
            tests.
        strides, padding, dilations: convolution attributes, indexed as above.
        error_name: ErrorIf case being generated, or None for a valid test.
    """
    spatial = [
        (
            ifm.shape[k + 1]
            - 1
            + padding[2 * k]
            + padding[2 * k + 1]
            - (filter.shape[k + 1] - 1) * dilations[k]
        )
        // strides[k]
        + 1
        for k in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change picks D (1), H (2), W (3) or all three (4). Growing by a
        # multiple of the stride avoids the non-integer error case.
        for k in range(3):
            if change in (k + 1, 4):
                spatial[k] = spatial[k] + rng.choice(choices) * strides[k]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # With a wrong input type, fall back to some potentially correct dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5314
5315 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M), with
    ``filter.shape[2] * filter.shape[3]`` output channels. For spatial
    axis k (0=H, 1=W):

        out[k] = (ifm.shape[k+1] - 1 + padding[2k] + padding[2k+1]
                  - (filter.shape[k] - 1) * dilations[k]) // strides[k] + 1

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator used to fuzz shape/dtype for ERROR_IF tests.
        ifm: input feature map tensor.
        filter: weight tensor.
        accum_dtype: accumulator dtype; used as the output dtype for valid
            tests.
        strides, padding, dilations: convolution attributes, indexed as above.
        error_name: ErrorIf case being generated, or None for a valid test.
    """
    spatial = [
        (
            ifm.shape[k + 1]
            - 1
            + padding[2 * k]
            + padding[2 * k + 1]
            - (filter.shape[k] - 1) * dilations[k]
        )
        // strides[k]
        + 1
        for k in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # change picks H (1), W (2) or both (3). Growing by a multiple of the
        # stride avoids the non-integer error case.
        for k in range(2):
            if change in (k + 1, 3):
                spatial[k] = spatial[k] + rng.choice(choices) * strides[k]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    # With a wrong input type, fall back to some potentially correct dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005365
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add and return the NHWC output tensor for a 2D pooling operator.

    kernel and stride are indexed [0] for height and [1] for width;
    pad[0:2] pads the height axis and pad[2:4] the width axis.
    A non-positive stride or negative padding yields a 1x1 spatial output.
    ERROR_IF cases may corrupt the spatial size or the output dtype.
    """
    invalid_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if invalid_params:
        # The test is invalid anyway, so just use 1x1 spatial dimensions
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        axis_sel = rng.choice(steps)
        # Grow by stride multiples so the non-integer error case is not hit
        if axis_sel in (1, 3):
            out_h = out_h + rng.choice(steps) * stride[0]
        if axis_sel in (2, 3):
            out_w = out_w + rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005405
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor for a FULLY_CONNECTED operator.

    input is (N, IC) and filter is (OC, IC), so the output is (N, OC).
    The output dtype is accum_dtype unchanged; it is validated (or
    deliberately invalidated for ERROR_IF) by the argument generator.
    rng and error_name are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005418
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor for a MATMUL operator.

    a is (N, H, C) and b is (N, C, W), so the output is (N, H, W).
    The output dtype is:
      - a random type outside an input-dependent exclusion set for
        ErrorIf.WrongOutputType,
      - INT32 for ErrorIf.WrongInputType,
      - accum_dtype otherwise (validated in the argument generator).
    """
    out_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Every candidate type, in the fixed order used for the random pick
        candidates = [
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        # Types removed from the candidates for each input type
        # (presumably the legal output types for that input).
        if a.dtype == DType.INT8:
            excluded = (DType.INT32,)
        elif a.dtype == DType.INT16:
            excluded = (DType.INT48,)
        elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
            excluded = (DType.FP16,)
        elif a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16:
            excluded = (DType.FP32, DType.FP16, DType.BF16)
        for legal in excluded:
            candidates.remove(legal)
        out_dtype = rng.choice(a=candidates)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005484
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add and return the output tensor for a CONCAT operator.

    The output starts from a copy of inputs[0].shape. Dimension `axis` is
    extended by the matching dimension of every other input. The sum is
    skipped when the ranks differ or the axis is invalid (the shape of
    inputs[0] is then used). ERROR_IF cases may inflate the concatenated
    dimension by 5-9 or change the dtype.
    """
    base = inputs[0]
    out_shape = base.shape.copy()

    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {base.dtype}))
    else:
        out_dtype = base.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005520
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor for a PAD operator.

    Both padding[i][0] and padding[i][1] are added to dimension i of a.
    ERROR_IF cases may:
      - grow one random dimension by 1 or 2,
      - replace the shape with one of a different rank,
      - change the dtype.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005551
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add and return the single-element output tensor of a DIM operator.

    The output has shape [1] and dtype DType.SHAPE. For
    ErrorIf.WrongOutputType a random non-SHAPE dtype is used instead.
    a and axis are not used here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    candidates = set(
        [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
    )
    return ser.addOutput([1], rng.choice(list(candidates)))
5572
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor for a RESHAPE operator.

    The output takes a copy of the requested shape and a's dtype. ERROR_IF
    cases may grow every dimension or change the dtype.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow each dimension by 1-9 (presumably so the element count no
        # longer matches the input).
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005597
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add and return the output tensor for a SLICE operator.

    The output shape is a copy of size and the dtype is input.dtype.
    ERROR_IF cases may:
      - perturb every dimension,
      - reuse the input shape,
      - change the rank or the dtype.
    start is not used here.
    """
    # The dtype is drawn before the shape, matching the rng draw order.
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Only shrink dimensions big enough to stay positive
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
5632 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005633 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005634
5635 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08005636 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005637
5638 for i in range(len(output_shape)):
5639 output_shape[i] = a.shape[i] * multiples[i]
5640
Luke Huttona4e48ca2023-02-22 11:53:48 +00005641 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005642 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005643
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005644 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005645 all_dtypes = [
5646 DType.INT8,
5647 DType.INT16,
5648 DType.INT32,
5649 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005650 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005651 DType.FP16,
5652 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005653 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005654 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5655 outputDType = rng.choice(wrong_dtypes)
5656 else:
5657 outputDType = a.dtype
5658
5659 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005660
5661 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01005662 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005663 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01005664
Kevin Cheng550ccc52021-03-03 11:21:43 -08005665 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07005666
Luke Huttona4e48ca2023-02-22 11:53:48 +00005667 if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
Matthew Haddone807aae2021-10-11 18:12:58 +01005668 for i in range(len(output_shape)):
5669 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005670
Luke Huttona4e48ca2023-02-22 11:53:48 +00005671 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
5672 for i in range(len(output_shape)):
5673 output_shape[i] += rng.integers(1, 10)
5674 elif error_name == ErrorIf.RankMismatch:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005675 output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
Luke Huttona4e48ca2023-02-22 11:53:48 +00005676
Matthew Haddone807aae2021-10-11 18:12:58 +01005677 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005678 all_dtypes = [
5679 DType.INT8,
5680 DType.INT16,
5681 DType.INT32,
5682 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005683 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005684 DType.FP16,
5685 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005686 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01005687 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5688 outputDType = rng.choice(wrong_dtypes)
5689 else:
5690 outputDType = a.dtype
5691
5692 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005693
5694 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005695 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005696 if error_name != ErrorIf.WrongRank:
5697 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005698 assert len(indices.shape) == 2
5699 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07005700
Kevin Cheng77d0f762020-11-24 10:26:32 -08005701 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
5702
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005703 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005704 all_dtypes = [
5705 DType.INT8,
5706 DType.INT16,
5707 DType.INT32,
5708 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005709 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005710 DType.FP16,
5711 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005712 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005713 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
5714 outputDType = rng.choice(wrong_dtypes)
5715 else:
5716 outputDType = values.dtype
5717
5718 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005719
5720 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005721 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00005722 if error_name != ErrorIf.WrongRank:
5723 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08005724 assert len(indices.shape) == 2
5725 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08005726 assert values_in.shape[0] == indices.shape[0] # N
5727 assert input.shape[1] == indices.shape[1] # W
5728 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08005729
5730 output_shape = values_in.shape
5731
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005732 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005733 all_dtypes = [
5734 DType.INT8,
5735 DType.INT16,
5736 DType.INT32,
5737 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005738 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005739 DType.FP16,
5740 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005741 DType.FP8E4M3,
5742 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005743 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005744 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
5745 outputDType = rng.choice(wrong_dtypes)
5746 else:
5747 outputDType = values_in.dtype
5748
5749 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005750
5751 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005752 def tableOp(ser, rng, input, error_name=None):
5753 # Same shape as the input, dtype dependent on input dtype
5754 if error_name != ErrorIf.WrongInputType:
5755 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00005756 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005757 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005758 wrong_dtypes = [
5759 DType.INT8,
5760 DType.INT16,
5761 DType.INT32,
5762 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005763 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005764 DType.FP16,
5765 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005766 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005767 wrong_dtypes.remove(output_dtype)
5768 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01005769 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005770
5771 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08005772 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005773 serializer,
5774 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005775 input,
5776 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005777 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005778 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005779 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005780 input_dtype,
5781 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005782 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08005783 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005784 # Calculate OH, OW
5785 scale_y_n = scale[0]
5786 scale_y_d = scale[1]
5787 scale_x_n = scale[2]
5788 scale_x_d = scale[3]
5789 if error_name == ErrorIf.ScaleSmallerEqualZero:
5790 scale_y_n = max(scale_y_n, 1)
5791 scale_y_d = max(scale_y_d, 1)
5792 scale_x_n = max(scale_x_n, 1)
5793 scale_x_d = max(scale_x_d, 1)
5794
5795 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
5796 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
5797
5798 if error_name is not None:
5799 # Make sure the output tensor is valid, which can occur when
5800 # scale, offset or border have been changed for ERROR_IFs
5801 oh = max(oh, 1)
5802 ow = max(ow, 1)
5803 if error_name != ErrorIf.MaxDimExceeded:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005804 oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
5805 ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005806
5807 if error_name == ErrorIf.ResizeOutputShapeMismatch:
5808 choices = [1, 2, 3]
5809 change = rng.choice(choices)
5810 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
5811 if change in [1, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005812 if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005813 oh -= scale_y_d
5814 assert oh > 0 # Should have been caught in agResize
5815 else:
5816 oh += scale_y_d
5817 if change in [2, 3]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01005818 if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005819 ow -= scale_x_d
5820 assert ow > 0 # Should have been caught in agResize
5821 else:
5822 ow += scale_x_d
5823
Matthew Haddon848efb42021-09-09 12:30:53 +01005824 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005825 output_dims = [
5826 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005827 oh,
5828 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005829 input.shape[0],
5830 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005831 elif error_name == ErrorIf.BatchMismatch:
5832 output_dims = [
5833 input.shape[0] + rng.integers(1, 10),
5834 oh,
5835 ow,
5836 input.shape[3],
5837 ]
5838 elif error_name == ErrorIf.ChannelMismatch:
5839 output_dims = [
5840 input.shape[0],
5841 oh,
5842 ow,
5843 input.shape[3] + rng.integers(1, 10),
5844 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01005845 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01005846 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07005847
Matthew Haddon693ba9e2021-09-22 11:24:37 +01005848 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005849
5850 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005851 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005852 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005853
5854 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005855 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01005856 if error_name == ErrorIf.ConvOutputShapeMismatch:
5857 choices = [1, 2, 3]
5858 change = rng.choice(choices)
5859 if change in [1, 3]:
5860 output_shape[1] = output_shape[1] + rng.choice(choices)
5861 if change in [2, 3]:
5862 output_shape[2] = output_shape[2] + rng.choice(choices)
5863
James Ward8b390432022-08-12 20:48:56 +01005864 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00005865 # Pick some potentially correct output dtype if input type is incorrect
5866 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07005867 else:
James Ward8b390432022-08-12 20:48:56 +01005868 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00005869
5870 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01005871 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005872 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01005873 else:
5874 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005875 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00005876 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07005877
Kevin Cheng550ccc52021-03-03 11:21:43 -08005878 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005879
5880 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00005881 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
5882 outputs = []
5883
5884 assert ifm1.dtype == ifm2.dtype
5885 input_dtype = ifm1.dtype
5886
5887 if error_name != ErrorIf.FFTInputShapeMismatch:
5888 assert ifm1.shape == ifm2.shape
5889
5890 input_shape = ifm1.shape
5891 if error_name != ErrorIf.WrongRank:
5892 assert len(input_shape) == 3
5893
5894 output_shape = input_shape.copy()
5895 output_dtype = input_dtype
5896
5897 if error_name == ErrorIf.WrongOutputType:
5898 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005899 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton57287132023-02-06 14:54:18 +00005900 output_dtype = rng.choice(wrong_dtypes)
5901 elif error_name == ErrorIf.BatchMismatch:
5902 output_shape[0] += rng.integers(1, 10)
5903 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5904 modify_dim = rng.choice([1, 2])
5905 output_shape[modify_dim] += rng.integers(1, 10)
5906
5907 outputs.append(serializer.addOutput(output_shape, output_dtype))
5908 outputs.append(serializer.addOutput(output_shape, output_dtype))
5909 return outputs
5910
5911 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00005912 def rfft2dOp(serializer, rng, value, error_name=None):
5913 outputs = []
5914
5915 input_shape = value.shape
5916 if error_name != ErrorIf.WrongRank:
5917 assert len(input_shape) == 3
5918
5919 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
5920
5921 output_dtype = value.dtype
5922 if error_name == ErrorIf.WrongOutputType:
5923 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01005924 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton261b7b62023-01-10 14:50:31 +00005925 output_dtype = rng.choice(wrong_dtypes)
5926 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00005927 output_shape[0] += rng.integers(1, 10)
5928 elif error_name == ErrorIf.FFTOutputShapeMismatch:
5929 modify_dim = rng.choice([1, 2])
5930 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00005931
5932 outputs.append(serializer.addOutput(output_shape, output_dtype))
5933 outputs.append(serializer.addOutput(output_shape, output_dtype))
5934 return outputs
Won Jeon74342e52024-01-09 00:34:40 +00005935
5936 @staticmethod
5937 def addShapeOp(ser, rng, a, b, error_name=None):
5938 if error_name != ErrorIf.RankMismatch:
5939 assert len(a.shape) == len(b.shape)
5940 assert a.dtype == b.dtype
5941
5942 shape = []
5943 for i in range(len(a.shape)):
5944 shape.append(a.shape[i])
5945
5946 fuzz_idx = rng.integers(0, len(a.shape))
5947 if error_name == ErrorIf.DimensionMismatch:
5948 shape[fuzz_idx] += 1
5949
5950 if error_name == ErrorIf.WrongOutputType:
5951 wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
5952 outputDType = rng.choice(wrong_dtypes)
5953 else:
5954 outputDType = DType.SHAPE
5955 return ser.addOutput(shape, outputDType)