blob: 978e735a017510653098604ebe3bb8e08819a472 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01005from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01006from datetime import datetime
7from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07008
Jeremy Johnson1271c442023-09-05 11:39:26 +01009import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000011import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010012from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_arg_gen import TosaArgGen
14from generator.tosa_arg_gen import TosaQuantGen
15from generator.tosa_arg_gen import TosaTensorGen
16from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000017from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010018from generator.tosa_error_if import TosaErrorIfArgGen
19from generator.tosa_error_if import TosaErrorValidator
20from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson1271c442023-09-05 11:39:26 +010021from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
# Header placed at the top of generated C++ source files; the copyright
# year is taken from the date on which the generator runs.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)
29
Matthew Haddonb724efc2021-08-25 16:40:29 +010030
Eric Kunzee5e26762020-10-13 16:11:07 -070031class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010032 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000033 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010034 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010035 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010036 TOSA_8K_LEVEL_MAX_KERNEL = 8192
37 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010038
Jeremy Johnson1271c442023-09-05 11:39:26 +010039 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000040 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010041 TOSA_MI_DOT_PRODUCT_MIN = 1000
42
Eric Kunzee5e26762020-10-13 16:11:07 -070043 def __init__(self, args):
44 self.args = args
45 self.basePath = args.output_dir
46 self.random_seed = args.random_seed
47 self.ser = None
48 self.rng = np.random.default_rng(self.random_seed)
49 self.createDynamicOpLists()
50 self.initOpListDefaults()
51 self.quantGen = TosaQuantGen()
52 # Force makeShape to do a specific starting shape
53 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010054 # JSON schema validation
55 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010056 # Data generator library is sometimes needed for compliance set up
57 # even if we are generating the data later (lazy_data_generation)
58 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010060 # Work out floating point range
61 def convertFPRange(rangeFP, maxFP):
62 # Converts program arguments of max/-max to FP max
63 vals = []
64 for v in rangeFP:
65 if v == "max":
66 v = maxFP
67 elif v == "-max":
68 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000069 elif v < 0:
70 # Trim to minimum data type value
71 v = max(v, -maxFP)
72 elif v > 0:
73 # Trim to maximum data type value
74 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010075 vals.append(v)
76 return tuple(sorted(vals))
77
78 self.random_float_range = {}
Won Jeon2c34b462024-02-06 18:37:00 +000079 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010080 self.random_float_range[dtype] = convertFPRange(
81 args.tensor_fp_value_range,
82 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
83 )
84
Eric Kunzee5e26762020-10-13 16:11:07 -070085 def createSerializer(self, opName, testPath):
86 self.testPath = os.path.join(opName, testPath)
87
88 fullPath = os.path.join(self.basePath, self.testPath)
89 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +010090 # Embed const data in the flatbuffer
91 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +010092 if self.args.lazy_data_gen:
93 # Lazy data generation - so make constants files
94 constMode = ts.ConstMode.INPUTS
95 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +010096 constMode = ts.ConstMode.EMBED_DUMP
97 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -070098
99 def getSerializer(self):
100 return self.ser
101
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 def serialize(self, testName, metaData=None):
103 path = Path(self.basePath) / self.testPath
104
105 # Write out TOSA flatbuffer binary
106 path_fb = path / f"{testName}.tosa"
107 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700108 fd.write(self.ser.serialize())
109
Jeremy Johnson1271c442023-09-05 11:39:26 +0100110 # Get JSON descriptor from serializer
111 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
112
113 if metaData:
114 # Add extra meta data to desc.json
115 desc["meta"] = metaData
116
117 # Validate desc.json before we output it
118 self.descSchemaValidator.validate_config(desc)
119
120 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100121 if "data_gen" in metaData:
122 if self.args.lazy_data_gen:
123 # Output datagen meta data as CPP data
124 path_md = path / f"{testName}_meta_data_gen.cpp"
125 with path_md.open("w") as fd:
126 fd.write(TOSA_AUTOGENERATED_HEADER)
127 fd.write("// Test meta data for data generation setup\n\n")
128 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
129 json.dump(metaData["data_gen"], fd)
130 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 if "compliance" in metaData:
132 # Output datagen meta data as CPP data
133 path_md = path / f"{testName}_meta_compliance.cpp"
134 with path_md.open("w") as fd:
135 fd.write(TOSA_AUTOGENERATED_HEADER)
136 fd.write("// Test meta data for compliance validation\n\n")
137 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
138 json.dump(metaData["compliance"], fd)
139 fd.write(')";\n\n')
140
141 # Write desc.json
142 path_desc = path / "desc.json"
143 with path_desc.open("w") as fd:
144 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
Matthew Haddon74567092021-07-16 15:38:20 +0100146 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000147 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +0100148 seed = self.random_seed + 1
149 self.rng = np.random.default_rng(seed)
150
Jeremy Johnson1271c442023-09-05 11:39:26 +0100151 def getDTypeRange(self, dtype, high_inclusive=False):
152 # Returns dtype value range boundaries (low, high)
153 # The high boundary is excluded in the range
154 # unless high_inclusive is True
Won Jeon2c34b462024-02-06 18:37:00 +0000155 if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100156 return self.random_float_range[dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100157 elif dtype == DType.BOOL:
158 rng = (0, 2)
159 elif dtype == DType.UINT8:
160 rng = (0, 256)
161 elif dtype == DType.UINT16:
162 rng = (0, 65536)
163 elif dtype == DType.INT4:
164 # TOSA specific INT4 weight range from -7 to 7
165 rng = (-7, 8)
166 elif dtype == DType.INT8:
167 rng = (-128, 128)
168 elif dtype == DType.INT16:
169 rng = (-32768, 32768)
Won Jeon74342e52024-01-09 00:34:40 +0000170 elif dtype == DType.INT32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100171 rng = (-(1 << 31), (1 << 31))
Won Jeon74342e52024-01-09 00:34:40 +0000172 elif dtype == DType.SHAPE:
173 rng = tuple(self.args.tensor_shape_range[0:2])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100174 elif dtype == DType.INT48:
175 rng = (-(1 << 47), (1 << 47))
176 else:
177 raise Exception("Unknown dtype: {}".format(dtype))
178
179 if not high_inclusive:
180 # Exclusive high: low <= range < high
181 return rng
182 else:
183 # Inclusive range: low <= range <= high
184 return (rng[0], rng[1] - 1)
185
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000186 def getRandTensor(self, shape, dtype, data_range=None):
187 if data_range is None:
188 low, high = self.getDTypeRange(dtype)
189 else:
190 low, high = data_range
Jeremy Johnson1271c442023-09-05 11:39:26 +0100191
Eric Kunzee5e26762020-10-13 16:11:07 -0700192 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Jerry Gec5291692024-01-02 22:29:08 +0000194 elif dtype == DType.INT8:
195 return np.int8(self.rng.integers(low=low, high=high, size=shape))
196 elif dtype == DType.UINT8:
197 return np.uint8(self.rng.integers(low=low, high=high, size=shape))
Jerry Ge20ab3df2024-01-26 16:56:55 +0000198 elif dtype == DType.INT16:
199 return np.int16(self.rng.integers(low=low, high=high, size=shape))
200 elif dtype == DType.UINT16:
201 return np.uint16(self.rng.integers(low=low, high=high, size=shape))
Won Jeon74342e52024-01-09 00:34:40 +0000202 elif dtype in (DType.INT48, DType.SHAPE):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100203 return np.int64(self.rng.integers(low=low, high=high, size=shape))
Won Jeon2c34b462024-02-06 18:37:00 +0000204 elif dtype in (
205 DType.FP16,
206 DType.BF16,
207 DType.FP32,
208 DType.FP8E4M3,
209 DType.FP8E5M2,
210 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100211 f_tensor = self.rng.uniform(low=low, high=high, size=shape)
212
213 if dtype == DType.FP16:
214 return np.float16(f_tensor)
215 else:
216 f32_tensor = np.float32(f_tensor)
217 if dtype == DType.BF16:
218 # Floor the last 16 bits of each f32 value
219 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
Won Jeon2c34b462024-02-06 18:37:00 +0000220 elif dtype == DType.FP8E4M3:
221 return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
222 elif dtype == DType.FP8E5M2:
223 return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
Jeremy Johnson1271c442023-09-05 11:39:26 +0100224 else:
225 return f32_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -0700226 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100227 # All other integer types
228 return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700229
Kevin Cheng989cb052021-04-28 16:29:44 -0700230 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700231 placeholders = []
232
Kevin Cheng989cb052021-04-28 16:29:44 -0700233 assert len(shape_list) == len(dtype_list)
234
Jeremy Johnson1271c442023-09-05 11:39:26 +0100235 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700236 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100237 if not self.args.lazy_data_gen:
238 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700239 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700240
241 return placeholders
242
Kevin Cheng989cb052021-04-28 16:29:44 -0700243 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 consts = []
245
Kevin Cheng989cb052021-04-28 16:29:44 -0700246 assert len(shape_list) == len(dtype_list)
247
Jeremy Johnson1271c442023-09-05 11:39:26 +0100248 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700249 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100250 if not self.args.lazy_data_gen:
251 arr = self.getRandTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700252 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700253
254 return consts
255
256 def makeShape(self, rank):
257 if self.targetted_shape:
258 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800259 return np.int32(
260 self.rng.integers(
261 low=self.args.tensor_shape_range[0],
262 high=self.args.tensor_shape_range[1],
263 size=rank,
264 )
265 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700266
267 def setTargetShape(self, shape):
268 self.targetted_shape = shape
269
270 def randInt(self, low=0, high=256):
271 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
272
273 def getRandNumberDType(self, dtype):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100274 low, high = self.getDTypeRange(dtype)
275
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100276 if dtype == DType.FP32:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100277 return np.float32(self.rng.uniform(low=low, high=high))
James Ward8b390432022-08-12 20:48:56 +0100278 elif dtype == DType.FP16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100279 return np.float16(self.rng.uniform(low=low, high=high))
James Ward24dbc422022-10-19 12:20:31 +0100280 elif dtype == DType.BF16:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100281 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
282 return gtu.vect_f32_to_bf16(rand_f32)
Won Jeon2c34b462024-02-06 18:37:00 +0000283 elif dtype == DType.FP8E4M3:
284 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
285 return gtu.vect_f32_to_fp8e4m3(rand_f32)
286 elif dtype == DType.FP8E5M2:
287 rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
288 return gtu.vect_f32_to_fp8e5m2(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 elif dtype == DType.BOOL:
290 return self.rng.choice([False, True])
Tai Ly8690a082023-12-18 20:40:24 +0000291 elif dtype == DType.INT48 or dtype == DType.SHAPE:
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 # Special size
293 return np.int64(self.rng.integers(low, high, size=1))[0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700294
295 return np.int32(self.rng.integers(low, high, size=1))[0]
296
297 def shapeStr(self, shape):
298
299 sStr = []
300 # Convert to strings
301 for i in shape:
302 sStr.append(str(i))
303
Kevin Cheng550ccc52021-03-03 11:21:43 -0800304 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700305
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100306 def typeStr(self, dtype):
307 if isinstance(dtype, list) or isinstance(dtype, tuple):
308 assert len(dtype) >= 2
309 strs = [self.typeStr(t) for t in dtype]
310 # Limit types to the first 2 as the 3rd is the accumulator
311 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700312 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100313 if dtype in gtu.DTYPE_ATTRIBUTES:
314 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700315 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100316 raise Exception(
317 "Unknown dtype, cannot convert to string: {}".format(dtype)
318 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700319
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100320 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100321 """Get the datatype width for data types"""
Jeremy Johnson1271c442023-09-05 11:39:26 +0100322 if dtype in gtu.DTYPE_ATTRIBUTES:
323 return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700324 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100325 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700326
Luke Hutton57287132023-02-06 14:54:18 +0000327 def constrictBatchSize(self, shape):
328 # Limit the batch size unless an explicit target shape set
329 if self.args.max_batch_size and not self.args.target_shapes:
330 shape[0] = min(shape[0], self.args.max_batch_size)
331 return shape
332
James Ward30124a82023-02-02 14:56:33 +0000333 def makeDimension(self):
334 return self.randInt(
335 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
336 )
337
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100338 def tensorComplianceMetaData(
339 self, op, inputType, argsDict, outputTensor, errorName
340 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000341 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
342 UNSUPPORTED_NON_FP32_INPUT_OPS = (
343 Op.MATMUL,
344 Op.CONV2D,
345 Op.FULLY_CONNECTED,
346 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000347 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000348 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000349 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100350 if (
351 errorName
352 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000353 or (
354 not gtu.dtypeIsSupportedByCompliance(inputType)
355 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
356 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100357 ):
358 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100359 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100360
Jeremy Johnson1271c442023-09-05 11:39:26 +0100361 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100362 compliance_tens = {
363 "mode": None,
364 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
365 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
366 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100367 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
368 mode = gtu.ComplianceMode.DOT_PRODUCT
369 compliance_tens["dot_product_info"] = {
370 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100371 "ks": int(argsDict["ksb"])
372 if "ksb" in argsDict
373 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100374 }
375 elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
376 mode = gtu.ComplianceMode.FP_SPECIAL
377 elif "compliance" in op and "ulp" in op["compliance"]:
378 mode = gtu.ComplianceMode.ULP
379 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000380 elif "compliance" in op and "relative" in op["compliance"]:
381 mode = gtu.ComplianceMode.RELATIVE
382 compliance_tens["relative_info"] = {
383 "max": argsDict["max_abs_value"],
384 "scale": op["compliance"]["relative"],
385 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100386 elif op["op"] == Op.REDUCE_PRODUCT:
387 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000388 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000389 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000390 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000391 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
392 compliance_tens["abs_error_info"] = {
393 "lower_bound": op["compliance"]["abs_error_lower_bound"]
394 }
Jerry Ge51bd4f52024-02-20 11:21:19 -0800395 elif op["op"] in (Op.SIN, Op.COS):
396 mode = gtu.ComplianceMode.ABS_ERROR
397 if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
398 compliance_tens["abs_error_info"] = {
399 "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
400 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100401 else:
402 mode = gtu.ComplianceMode.EXACT
403 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
404
405 return compliance_tens
406
407 # Build Op functions
408 # Create the output tensor (calling OutputShaper as needed)
409 # Do final tweaks to attributes (if necessary for errorIf)
410 # Add Op into graph
411 # Return resulting tensor information or BuildInfo
412
413 class BuildInfo:
414 """Enhanced build information containing result tensor and associated compliance dict."""
415
416 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000417 if isinstance(resultTensor, list):
418 assert complianceDict is None or isinstance(complianceDict, list)
419 self.resultTensorList = resultTensor
420 self.complianceDictList = complianceDict
421 else:
422 self.resultTensorList = [resultTensor]
423 if complianceDict is None:
424 self.complianceDictList = None
425 else:
426 self.complianceDictList = [complianceDict]
427
428 def getComplianceInfo(self):
429 if self.complianceDictList is None:
430 return None
431 else:
432 tens_dict = {}
433 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
434 if comp is not None:
435 tens_dict[tens.name] = comp
436
437 if tens_dict:
438 # Have some compliance data, so return the info
439 compliance = {
440 "version": "0.1",
441 "tensors": tens_dict,
442 }
443 else:
444 compliance = None
445 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700446
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000447 def build_unary(
448 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
449 ):
450 assert len(inputs) == 1
451 a = inputs[0]
452 result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100453
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000454 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100455
456 # Ensure new output type has correct qinfo
457 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000458 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000459 qinfo = [
460 TosaQuantGen.getZeroPoint(self, a.dtype),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000461 TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000462 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100463
464 # Invalidate Input/Output list for error if checks.
465 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000466 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100467 pCount, cCount = op["operands"]
468 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000469 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
470 self, error_name, input_list, output_list
471 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100472
Les Bell729b0352021-11-24 10:28:21 +0000473 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100474 self.ser,
475 validator_fcns,
476 error_name,
477 op=op,
478 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000479 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000480 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000481 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100482 input_list=input_list,
483 output_list=output_list,
484 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000485 ):
486 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100487
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000488 attr = None
489 if op["op"] == Op.NEGATE:
490 attr = ts.TosaSerializerAttribute()
491 attr.NegateAttribute(qinfo[0], qinfo[1])
492
493 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000494
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000495 compliance = self.tensorComplianceMetaData(
496 op, a.dtype, args_dict, result_tensor, error_name
497 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000498 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700499
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000500 def build_binary_broadcast(
501 self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
502 ):
503 assert len(inputs) == 2
504 a, b = inputs
505 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000506 self.ser, self.rng, a, b, error_name
507 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100508
509 # Invalidate Input/Output list for error if checks.
510 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000511 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100512 pCount, cCount = op["operands"]
513 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000514 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
515 self, error_name, input_list, output_list
516 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100517
Les Bell729b0352021-11-24 10:28:21 +0000518 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100519 self.ser,
520 validator_fcns,
521 error_name,
522 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000523 input1=a,
524 input2=b,
525 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000526 output_dtype=result_tensor.dtype,
527 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100528 input_list=input_list,
529 output_list=output_list,
530 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000531 ):
532 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100533
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000534 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000535
Jeremy Johnson9a758382023-11-07 16:27:35 +0000536 compliance = self.tensorComplianceMetaData(
537 op, a.dtype, args_dict, result_tensor, error_name
538 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000539
540 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700541
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100542 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700543 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000544 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700545 return result_tens
546
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000547 def build_arithmetic_right_shift(
Jeremy Johnson587cc842024-02-08 11:45:44 +0000548 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000549 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000550 assert len(inputs) == 2
551 a, b = inputs
552 round = args_dict["round"]
553 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000554 self.ser, self.rng, a, b, error_name
555 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100556
557 # Invalidate Input/Output list for error if checks.
558 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000559 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100560 pCount, cCount = op["operands"]
561 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000562 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
563 self, error_name, input_list, output_list
564 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100565
Les Bell729b0352021-11-24 10:28:21 +0000566 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100567 self.ser,
568 validator_fcns,
569 error_name,
570 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000571 input1=a,
572 input2=b,
573 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000574 output_dtype=result_tensor.dtype,
575 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100576 input_list=input_list,
577 output_list=output_list,
578 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000579 ):
580 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800581
582 attr = ts.TosaSerializerAttribute()
583 attr.ArithmeticRightShiftAttribute(round)
584
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000585 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000586
587 compliance = self.tensorComplianceMetaData(
588 op, a.dtype, args_dict, result_tensor, error_name
589 )
590
591 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800592
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100593 def build_mul(
594 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
595 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000596 # Note that mul is binary operator but it has a shift value tensor
597 assert len(inputs) == 3
598 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100599
600 result_tensor = OutputShaper.binaryBroadcastOp(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000601 self.ser, self.rng, a, b, error_name
602 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700603
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100604 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100605 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100606 result_tensor.setDtype(DType.INT32)
607
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100608 if error_name == ErrorIf.WrongOutputType:
609 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
610 outputDType = self.rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100611 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100612
613 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000614 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100615 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100616 pCount, cCount = op["operands"]
617 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000618 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
619 self, error_name, input_list, output_list
620 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100621
Les Bell729b0352021-11-24 10:28:21 +0000622 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100623 self.ser,
624 validator_fcns,
625 error_name,
626 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000627 input1=a,
628 input2=b,
629 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100630 output_dtype=result_tensor.dtype,
631 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100632 input_list=input_list,
633 output_list=output_list,
634 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000635 ):
636 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700637
Jeremy Johnson0a042992024-02-28 13:20:05 +0000638 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100639
640 compliance = self.tensorComplianceMetaData(
641 op, a.dtype, args_dict, result_tensor, error_name
642 )
643
644 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700645
def build_table(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a TABLE operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments; "table" supplies the table attribute values.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and compliance data,
        or None when TosaErrorValidator.evValidateErrorIfs returns a falsy
        value.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    table_values = args_dict["table"]
    out_tensor = OutputShaper.tableOp(self.ser, self.rng, ifm, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table_values)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700688
def build_select(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a SELECT operator test from a condition and two value tensors.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [condition, a, b].
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    cond, a, b = inputs

    out_tensor = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # For ERROR_IF tests the operand name lists may be invalidated here.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    # Compliance is keyed on the dtype of the first value tensor (not cond).
    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, a.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700736
def build_comparison(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a binary comparison operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: The two operand tensors [lhs, rhs].
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, lhs, rhs, error_name
    )

    # For ERROR_IF tests the operand name lists may be invalidated here.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700784
def build_argmax(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build an ARGMAX operator test along a single axis.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments; "axis" selects the reduction axis.
        validator_fcns: ERROR_IF validator functions to run (defaults to None,
            consistent with the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700828
def build_pool2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments; reads "stride", "pad", "kernel" and the
            optional "acc_type" (absent for max_pool, which has no
            accumulator type).
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to PoolAttribute; [0, 0] is used
            when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, self.rng, ifm, kernel, stride, pad, error_name
    )

    # Recompute zero points so they match the (deliberately wrong) input type.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # For ERROR_IF tests the operand name lists may be invalidated here.
    p_count, c_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )

    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, ifm.dtype, args_dict, out_tensor, error_name
        ),
    )
James Ward8b390432022-08-12 20:48:56 +0100903
def build_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [input, weight, bias].
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (exactly
            4 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to ConvAttribute as qinfo[0] and
            qinfo[1] (presumably input and weight zero points). [0, 0] is used
            when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700985
def build_conv3d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [input, weight, bias].
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (exactly
            6 values) and "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to ConvAttribute as qinfo[0] and
            qinfo[1] (presumably input and weight zero points). [0, 0] is used
            when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001067
def build_transpose_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [input, weight, bias].
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (exactly
            4 output-padding values) and "out_shape".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to TransposeConvAttribute as
            qinfo[0] and qinfo[1] (presumably input and weight zero points).
            [0, 0] is used when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001142
def build_depthwise_conv2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [input, weight, bias].
        args_dict: Test arguments; reads "acc_type", "stride", "pad" and
            "dilation".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to ConvAttribute as qinfo[0] and
            qinfo[1] (presumably input and weight zero points). [0, 0] is used
            when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001223
def build_fully_connected(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Three tensors in the order [input, weight, bias].
        args_dict: Test arguments; "acc_type" gives the accumulator dtype.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to FullyConnectedAttribute as
            qinfo[0] and qinfo[1]. [0, 0] is used when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001279
def build_matmul(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator test.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: The two operand tensors [a, b].
        args_dict: Test arguments; "acc_type" gives the accumulator dtype.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two zero-point values passed to MatMulAttribute as qinfo[0] and
            qinfo[1] (presumably for a and b). [0, 0] is used when None.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None in the signature; fall back to zero offsets
    # instead of failing with a TypeError when indexing it below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001329
def build_reduce(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a REDUCE_* operator test along a single axis.

    Args:
        op: Operator definition dict; "op" and "operands" are read.
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments; "axis" selects the reduction axis. For a
            valid REDUCE_PRODUCT test, "n" is written into this dict
            (in place) before the compliance metadata is generated.
        validator_fcns: ERROR_IF validator functions to run (defaults to None,
            consistent with the other build_* methods).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and compliance data, or
        None when TosaErrorValidator.evValidateErrorIfs returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001378
def build_clamp(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For the
    ErrorIf.MaxSmallerMin test the pair is redrawn until the values differ,
    and the roles are then swapped so that max < min.

    Floating-point dtypes (BF16/FP16/FP32) put the bounds in the fp slots of
    ClampAttribute. INT8/INT16 put them in the integer slots. Any other dtype
    writes all-zero bounds.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    bounds = [self.getRandNumberDType(src.dtype) for _ in range(2)]
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot trigger this error, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = [self.getRandNumberDType(src.dtype) for _ in range(2)]
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    # Any dtype other than the ones below keeps all-zero bounds, which avoids
    # an internal error for incorrect input types.
    int_bounds = (0, 0)
    fp_bounds = (0, 0)
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 values are passed as fp32 in reference_model.
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        fp_bounds = (min_val, max_val)
    elif src.dtype in (DType.INT8, DType.INT16):
        int_bounds = (min_val, max_val)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, *int_bounds, *fp_bounds)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001447
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add ``op["op"]`` with a LeakyReluAttribute holding a random FP32 value.

    This uses the single-tensor calling convention: tensor ``a`` is passed
    directly and the result tensor is returned, not a BuildInfo.
    ``validator_fcns`` is accepted but not used, so no ERROR_IF validation
    runs here.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    leaky_attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))

    self.ser.addOperator(op["op"], [a.name], [out.name], leaky_attr)
    return out
1456
1457 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add ``op["op"]`` with tensor ``a`` as its only input and no attribute.

    This uses the single-tensor calling convention and returns the result
    tensor directly. ``validator_fcns`` is accepted but not used.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1463
def build_activation(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add an attribute-less operator over a single input tensor.

    The output shape and dtype come from OutputShaper.unaryOp.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    src = inputs[0]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001504
def build_concat(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a CONCAT or CONCAT_SHAPE operator joining every tensor in ``inputs``.

    CONCAT_SHAPE always uses axis 0 and is serialized without an attribute.
    CONCAT reads its axis from ``args_dict["axis"]`` and stores it in an
    AxisAttribute. Compliance metadata is based on the first input's dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out = OutputShaper.concatOp(
        self.ser, self.rng, axis, inputs, error_name=error_name
    )

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [tensor.name for tensor in inputs], [out.name]
    )

    first = inputs[0]
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out.shape,
        input_dtype=first.dtype,
        output_dtype=out.dtype,
        inputs=inputs,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        concat_attr = None
    self.ser.addOperator(op["op"], in_names, out_names, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001563
def build_pad(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator whose padding amounts come from the second input.

    ``args_dict["pad"]`` is used to compute the output shape and for ERROR_IF
    validation. The PadAttribute itself gets an empty padding list, so the
    padding tensor ``inputs[1]`` is what the operator consumes. The pad
    constants come from ``args_dict["pad_const_int"]`` and
    ``args_dict["pad_const_fp"]``.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    src, pad_input = inputs[0], inputs[1]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    out = OutputShaper.padOp(self.ser, self.rng, src, padding, error_name)

    # An empty padding list in the attribute ensures inputs[1] is used.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, pad_input.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001621
def build_dim(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator querying ``args_dict["axis"]`` of a single input.

    No compliance metadata is produced for this operator.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None), or None when
        TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.dimOp(self.ser, self.rng, src, axis, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    dim_attr = ts.TosaSerializerAttribute()
    dim_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, dim_attr)

    return TosaTestGen.BuildInfo(out, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001667
def build_reshape(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a RESHAPE operator whose target shape is supplied as a tensor.

    ``inputs[1]`` (the shape tensor) is wired as the second operator input.
    ``args_dict["new_shape"]`` is used to compute the output tensor shape.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    src, shape = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    out = OutputShaper.reshapeOp(self.ser, self.rng, src, new_shape, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, shape.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001711
def build_reverse(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a REVERSE operator along ``args_dict["axis"]`` of a single input.

    The output is shaped with OutputShaper.unaryOp. No compliance metadata
    is produced.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None), or None when
        TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out = OutputShaper.unaryOp(self.ser, self.rng, src, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    reverse_attr = ts.TosaSerializerAttribute()
    reverse_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, reverse_attr)

    return TosaTestGen.BuildInfo(out, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001751
def build_transpose(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    The permutation is stored in a TransposeAttribute and is also passed to
    the ERROR_IF validators.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 1
    src = inputs[0]
    perms = args_dict["perms"]

    out = OutputShaper.transposeOp(self.ser, self.rng, src, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001800
def build_slice(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SLICE operator whose start/size are supplied as input tensors.

    ``inputs`` is (data, start tensor, size tensor). The matching constant
    values ``args_dict["start"]`` and ``args_dict["size"]`` are used to
    compute the output shape and for ERROR_IF validation.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    out = OutputShaper.sliceOp(
        self.ser, self.rng, src, start_values, size_values, error_name
    )

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [out.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001848
def build_tile(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a TILE operator whose multiples are supplied as a second input.

    ``inputs[1]`` (the multiples tensor) is wired as an operator input.
    ``args_dict["multiples"]`` is used to compute the output shape.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    multiples = args_dict["multiples"]
    out = OutputShaper.tileOp(self.ser, self.rng, src, multiples, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name, multiples_tensor.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out.shape,
        input_dtype=src.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
        input1=src,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001893
def build_gather(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a GATHER operator from ``inputs`` = (values, indices).

    Validation and compliance metadata are based on the ``values`` tensor's
    shape and dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out.shape,
        input_dtype=values.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001936
def build_scatter(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a SCATTER operator from ``inputs`` = (values_in, indices, input).

    Validation and compliance metadata are based on the ``values_in``
    tensor's shape and dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when TosaErrorValidator.evValidateErrorIfs reports a failure.
    """
    assert len(inputs) == 3
    # Named input_tensor to avoid shadowing the builtin input().
    values_in, indices, input_tensor = inputs

    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input_tensor, error_name
    )

    # ERROR_IF tests may have their input/output name lists invalidated.
    n_placeholders, n_consts = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [out.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, out, error_name
    )
    return TosaTestGen.BuildInfo(out, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001978
def build_resize(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` is ``[input, scale, offset, border]``. The scale, offset and
    border tensors are wired in as operator inputs, so the serialized
    ``ResizeAttribute`` carries only ``mode`` (with empty scale/offset/border
    lists). The Python-side ``scale``/``offset``/``border`` values from
    ``args_dict`` are still passed to ``OutputShaper.resizeOp`` and to the
    ERROR_IF validators.

    Args:
        op: Operator-list entry; ``op["op"]`` and ``op["operands"]`` are used.
        inputs: Four tensors, as described above.
        args_dict: Must provide ``mode``, ``scale``, ``offset``, ``border``
            and ``output_dtype``.
        validator_fcns: ERROR_IF validator functions (defaults to None, like
            the other build_* methods).
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused.

    Returns:
        ``None`` when the ERROR_IF validators reject the configuration,
        otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and
        compliance metadata.
    """
    assert len(inputs) == 4
    input_tensor, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # write empty scale/offset/border into ResizeAttribute
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002057
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two pass-through inputs.

    NOTE(review): unlike the other build_* methods this keeps the older
    positional-tensor signature and returns the first result tensor rather
    than a ``TosaTestGen.BuildInfo``; ``validator_fcns`` is not used.

    Returns:
        The result tensor produced for ``val``.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The sibling builders receive ``op`` as an operator-list entry and pass
    # ``op["op"]`` to the serializer; accept that form or a bare opcode.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
2065
def build_const(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CONST test.

    The single supplied tensor is registered directly as a graph output; no
    operator node is added here.

    Returns:
        A ``TosaTestGen.BuildInfo`` for that tensor with its compliance
        metadata.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]
    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002078
# Type Conversion
def build_cast(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a CAST operator test.

    Converts the single input tensor to ``args_dict["out_type"]``.

    Returns:
        ``None`` when the ERROR_IF validators reject the configuration,
        otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and
        compliance metadata.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, src, args_dict["out_type"], error_name
    )

    param_count, const_count = op["operands"]

    # The tensor name lists may be altered for ERROR_IF test cases.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [src.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, src.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002123
def build_rescale(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Picks random input/output zero points and a random scale (one per
    channel when ``per_channel`` is set, otherwise one), and converts each
    scale to a multiplier/shift pair. For valid scale32 tests it also clamps
    the stored input data so that ``value - input_zp`` lies in the
    apply_scale_32 shift range; the placeholder file is rewritten if any
    value changes.

    Args:
        op: Operator-list entry; ``op["op"]`` and ``op["operands"]`` are used.
        inputs: Single-element list holding the input tensor.
        args_dict: Must provide ``output_dtype``, ``scale`` (truthy selects
            scale32), ``double_round`` and ``per_channel``.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Ignored; replaced by the generated ``(input_zp, output_zp)``.

    Returns:
        ``None`` when the ERROR_IF validators reject the configuration,
        otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and
        compliance metadata.
    """
    assert len(inputs) == 1
    val = inputs[0]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    input_unsigned = False
    output_unsigned = False

    # Input zero point; the width is widened by one bit whenever a zero
    # point may be applied.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
        input_unsigned = True
    else:
        input_zp = 0

    # Output zero point, mirroring the input logic above.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
        output_unsigned = True
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)

        # Check we can safely convert back to the stored dtype. Compare the
        # real extremes: ``ndarray.all()`` reduces to a single bool, which
        # would make this check vacuous.
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        ), f"adjusted RESCALE input values do not fit {values.dtype}"

        # Cast back to the stored input datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002306
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool ERROR_IF case substitutes a wrong dtype. The
    CondIfCondShapeNotSizeOne case uses shape ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2323
def build_cond_if_const(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output a constant.

    Only the shape of the supplied "then" tensor is used, for both branches;
    the input data is ignored. Each block emits a random INT32 constant with
    values in [0, 256). For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF cases, the targeted block's constant gets a deliberately
    different shape.

    Returns:
        ``None`` when the ERROR_IF validators reject the configuration,
        otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and
        compliance metadata.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs
    cond = args_dict["condition"]

    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    output_mismatch_errors = [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if error_name in output_mismatch_errors:
        # Perturb each dimension; only dims > 3 are shrunk, so all stay positive.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] = dim + self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor takes the shared branch output shape.
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Emit each block with its constant output.
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002408
def build_cond_if_binary(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks apply a binary op to (a, b).

    For FP32/BF16/FP16/INT32 inputs the then block uses ADD and the else
    block uses SUB. For INT8/INT16 inputs they use LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT respectively. Compliance metadata is computed for
    whichever op the condition value selects.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF cases,
    a copy of ``a`` with a perturbed shape is used as the targeted block's
    input, or its shape is used for that block's output.

    Returns:
        ``None`` when the ERROR_IF validators reject the configuration,
        otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and
        compliance metadata.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op mapping.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Only shrink dimensions large enough to stay positive; every
            # offset is non-zero, so the shape always mismatches.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise so this still fires when asserts are disabled (-O).
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2512
2513 def build_while_loop(
2514 self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
2515 ):
2516 assert len(inputs) == 1
2517 a = inputs[0]
2518 iter_val = args_dict["iterations"]
2519
Kevin Cheng550ccc52021-03-03 11:21:43 -08002520 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07002521
Kevin Cheng550ccc52021-03-03 11:21:43 -08002522 cond_block = "COND_BLOCK"
2523 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
2525 attr = ts.TosaSerializerAttribute()
2526 attr.WhileLoopAttribute(cond_block, body_block)
2527
2528 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08002529 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002530 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08002531 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07002532
2533 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002534 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2535 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002536 if error_name == ErrorIf.InputListOutputListMismatch:
2537 incorrect_acc = deepcopy(acc)
2538 for i in range(len(incorrect_acc.shape)):
2539 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2540 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2541 else:
2542 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002543
2544 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002545 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002546 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002547 [iter.name, a.name, acc.name],
2548 [iter_out.name, a_out.name, acc_out.name],
2549 attr,
2550 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002551 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002552
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002553 if error_name in [
2554 ErrorIf.InputListCondGraphMismatch,
2555 ErrorIf.InputListBodyGraphInputMismatch,
2556 ErrorIf.InputListBodyGraphOutputMismatch,
2557 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002558 incorrect_iter = deepcopy(iter)
2559 for i in range(len(incorrect_iter.shape)):
2560 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2561 if len(incorrect_iter.shape) == 0:
2562 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2563
2564 incorrect_acc = deepcopy(acc)
2565 for i in range(len(incorrect_acc.shape)):
2566 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2567
Eric Kunzee5e26762020-10-13 16:11:07 -07002568 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002569 self.ser.addBasicBlock(cond_block)
2570
Matthew Haddon630c17c2021-10-14 15:05:41 +01002571 if error_name == ErrorIf.InputListCondGraphMismatch:
2572 self.ser.addInputTensor(incorrect_iter)
2573 self.ser.addInputTensor(a)
2574 self.ser.addInputTensor(incorrect_acc)
2575 else:
2576 self.ser.addInputTensor(iter)
2577 self.ser.addInputTensor(a)
2578 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002579 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002580
2581 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002582 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002583 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002584 cond_type = DType.BOOL
2585 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2586 choice = self.rng.choice([1, 2])
2587 if choice == 1:
2588 cond_shape = [3]
2589 else:
2590 cond_shape = [1, 2]
2591 else:
2592 cond_shape = []
2593 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002594
Kevin Cheng550ccc52021-03-03 11:21:43 -08002595 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002596
2597 # BODY block (input: a, acc, iter, output: a, acc, iter)
2598 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002599 self.ser.addBasicBlock(body_block)
2600
Matthew Haddon630c17c2021-10-14 15:05:41 +01002601 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2602 self.ser.addInputTensor(incorrect_iter)
2603 self.ser.addInputTensor(a)
2604 self.ser.addInputTensor(incorrect_acc)
2605 else:
2606 self.ser.addInputTensor(iter)
2607 self.ser.addInputTensor(a)
2608 self.ser.addInputTensor(acc)
2609
Kevin Cheng550ccc52021-03-03 11:21:43 -08002610 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002611
2612 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002613 iter_body_out = self.ser.addIntermediate(
2614 incorrect_iter.shape, incorrect_iter.dtype
2615 )
2616 acc_body_out = self.ser.addIntermediate(
2617 incorrect_acc.shape, incorrect_acc.dtype
2618 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002619 else:
2620 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2621 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2622
Eric Kunzee5e26762020-10-13 16:11:07 -07002623 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2624 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2625 self.ser.addOutputTensor(iter_body_out)
2626 self.ser.addOutputTensor(a)
2627 self.ser.addOutputTensor(acc_body_out)
2628
Les Bell729b0352021-11-24 10:28:21 +00002629 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002630 self.ser,
2631 validator_fcns,
2632 error_name,
2633 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002634 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002635 ):
2636 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002637
Jeremy Johnson587cc842024-02-08 11:45:44 +00002638 compliance = self.tensorComplianceMetaData(
2639 op, a.dtype, args_dict, acc_out, error_name
2640 )
2641
2642 return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002643
def build_fft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from a pair of input tensors.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: sequence of exactly two input tensors.
        args_dict: generated arguments; must contain "inverse".
        validator_fcns: ERROR_IF validators to check against (may be None).
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, self.rng, lhs, rhs, error_name)

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts

    output_shapes = [tensor.shape for tensor in results]
    output_dtypes = [tensor.dtype for tensor in results]

    # ERROR_IF tests may deliberately corrupt the input/output name lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [lhs.name, rhs.name],
        [tensor.name for tensor in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, lhs.dtype, args_dict, tensor, error_name)
        for tensor in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002707
def build_rfft2d(
    self,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: sequence containing exactly one input tensor.
        args_dict: generated arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check against (may be None).
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensors and one compliance
        entry per result, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, self.rng, source, error_name)

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts

    output_shapes = [tensor.shape for tensor in results]
    output_dtypes = [tensor.dtype for tensor in results]

    # ERROR_IF tests may deliberately corrupt the input/output name lists
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [source.name],
        [tensor.name for tensor in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, source.dtype, args_dict, tensor, error_name)
        for tensor in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002764
def build_shape_op(
    self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a two-input shape operator test.

    The output tensor is created by OutputShaper.addShapeOp.

    Args:
        op: TOSA_OP_LIST entry for the operator.
        inputs: sequence of exactly two input tensors.
        args_dict: generated arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validators to check against (may be None).
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo for the single result tensor, or None when
        ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    first, second = inputs

    out_tensor = OutputShaper.addShapeOp(self.ser, self.rng, first, second, error_name)

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [first.name, second.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=first,
        input2=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
2810
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters for generating op tests.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes, or a falsy value for no shape filter.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required when testType is "negative".

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys. Returns
        None for a "negative" testType without a validator, or for any other
        testType.
    """
    # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by the default range AND by the
        # operator: use the op range when it is no larger than the default,
        # otherwise only the part of the default range the op supports.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # No overlap with the default range: use the lowest ranks the
                # op supports, keeping the same number of ranks as the default
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        # Explicit shapes given: allow every op rank, shapes filter by rank later
        cleanRankFilter = opRankRange

    dtypes = op["types"]

    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes; list
        # dtypes (e.g. conv type tuples) match on their first element
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2891
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the test cases to generate for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: shapes to restrict generation to; None (the default)
            means no shape restriction.
        rankFilter: ranks to restrict generation to, or None.
        dtypeFilter: dtypes to restrict generation to, or None.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        list of (opName, testStr, dtype, error_name, shapeList, args) tuples.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # None replaces the former shared mutable default; [None] means "no shape filter"
    if shapeFilter is None:
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        # The test name prefix only depends on the validator being used
        if testType == "negative":
            namePrefix = f"{opName}_ERRORIF_{error_name}"
        else:
            namePrefix = f"{opName}"

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = f"{namePrefix}_{shapeStr}_{typeStr}"
                        if argStr:
                            testStr = f"{testStr}_{argStr}"
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2998
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build one generated test case and serialize it if it is valid.

    A serializer is created for the test, the tensor values are produced by
    the op's tensor-values generator, and the op's build function is called.
    A truthy build result is written out via self.serialize("test", ...),
    together with any data_gen / compliance meta data; otherwise a message
    reporting an invalid ERROR_IF test is printed.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # CONCAT-style ops take a variable number of inputs, one per shape
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        dtype_count = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * dtype_count

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is only generated for ops that define a qgen
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"

    # New interface with args info in dictionary
    assert "dg_type" in argsDict
    tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it (with compliance meta data, if any)
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003089
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE entries of TOSA_OP_LIST into per-kernel ops.

    Each template is shallow-copied once per kernel size, with "filter" set
    to the kernel and "template" set to False. All entries whose "template"
    field is truthy are then removed. This is a no-op if "conv2d_TEMPLATE"
    is already gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate the convolution templates with
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    templates_2d = (
        ("conv2d_TEMPLATE", "conv2d"),
        ("depthwise_conv2d_TEMPLATE", "depthwise_conv2d"),
        ("transpose_conv2d_TEMPLATE", "transpose_conv2d"),
    )

    def _instantiate(template_name, test_name, kernel):
        # Clone a template entry into a concrete op for one kernel size
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[test_name] = entry

    for kernel in kernels_2d:
        suffix = "x".join(str(dim) for dim in kernel)
        for template_name, prefix in templates_2d:
            _instantiate(template_name, f"{prefix}_{suffix}", kernel)

    for kernel in kernels_3d:
        suffix = "x".join(str(dim) for dim in kernel)
        _instantiate("conv3d_TEMPLATE", f"conv3d_{suffix}", kernel)

    # Delete templates in a second pass - never delete while iterating the dict
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3145
def initOpListDefaults(self):
    """Validate every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide a two-element 'operands' tuple, a four-element
    'build_fcn' tuple, a 'types' field and an 'op' field. Entries without
    a 'rank' are given DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op whose required field is missing or
            malformed.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields - tuple unpacking checks presence and length together
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for required, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            try:
                entry[required]  # presence check only
            except KeyError:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003187
# Tensor operator list (TOSA_OP_LIST): each entry is a dictionary with
# 'op': the tosa Op enum value for the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
#    i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional quantization info generator
# 'error_if_validators': optional tuple of ERROR_IF validators, used to create
#    negative tests and passed to the build function
# 'invalid_test_validators': optional tuple of validators; positive tests for
#    which any validator returns True are not generated
# 'data_gen': optional data generator types keyed by dtype group, e.g.
#    {"fp": (gtu.DataGenType.PSEUDO_RANDOM,)}
# 'template': True for entries expanded by createDynamicOpLists, which also
#    sets 'filter' (the kernel size) on the generated entries
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
# floating-types, INT8/16/32 and BOOL
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

# INT8/INT16 and floating-types (excludes INT4 and INT32)
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
# (dtype filters match these list entries on their first element)
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range given to ops that don't specify 'rank' (see initOpListDefaults)
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003247
3248 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003249 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003250 "argmax": {
3251 "op": Op.ARGMAX,
3252 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003253 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003254 "build_fcn": (
3255 build_argmax,
3256 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003257 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003258 TosaArgGen.agAxis,
3259 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003260 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003261 "error_if_validators": (
3262 TosaErrorValidator.evAxisSmallerZero,
3263 TosaErrorValidator.evAxisLargerRank,
3264 TosaErrorValidator.evArgmaxOutputRankMismatch,
3265 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3266 TosaErrorValidator.evWrongRank,
3267 TosaErrorValidator.evWrongInputType,
3268 TosaErrorValidator.evWrongOutputType,
3269 TosaErrorValidator.evWrongInputList,
3270 TosaErrorValidator.evWrongOutputList,
3271 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003272 "data_gen": {
3273 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3274 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003275 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003276 "avg_pool2d": {
3277 "op": Op.AVG_POOL2D,
3278 "operands": (1, 0),
3279 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003280 "build_fcn": (
3281 build_pool2d,
3282 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003283 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 TosaArgGen.agPooling,
3285 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003287 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003288 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003289 "error_if_validators": (
3290 TosaErrorValidator.evKernelSmallerOne,
3291 TosaErrorValidator.evStrideSmallerOne,
3292 TosaErrorValidator.evPadSmallerZero,
3293 TosaErrorValidator.evWrongRank,
3294 TosaErrorValidator.evWrongInputType,
3295 TosaErrorValidator.evWrongOutputType,
3296 TosaErrorValidator.evWrongInputList,
3297 TosaErrorValidator.evWrongOutputList,
3298 TosaErrorValidator.evInputZeroPointNotZero,
3299 TosaErrorValidator.evOutputZeroPointNotZero,
3300 TosaErrorValidator.evPadLargerEqualKernel,
3301 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003302 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003303 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003304 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003305 "data_gen": {
3306 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003309 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003310 "conv2d_TEMPLATE": {
3311 "op": Op.CONV2D,
3312 "operands": (1, 2),
3313 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003314 "build_fcn": (
3315 build_conv2d,
3316 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003317 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003318 TosaArgGen.agConv,
3319 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003320 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003321 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003322 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3323 "error_if_validators": (
3324 TosaErrorValidator.evWrongInputType,
3325 TosaErrorValidator.evWrongOutputType,
3326 TosaErrorValidator.evWrongInputList,
3327 TosaErrorValidator.evWrongOutputList,
3328 TosaErrorValidator.evInputZeroPointNotZero,
3329 TosaErrorValidator.evWeightZeroPointNotZero,
3330 TosaErrorValidator.evPadSmallerZero,
3331 TosaErrorValidator.evStrideSmallerOne,
3332 TosaErrorValidator.evDilationSmallerOne,
3333 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003334 TosaErrorValidator.evConvOutputShapeMismatch,
3335 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003336 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003337 "data_gen": {
3338 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3339 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003340 "template": True,
3341 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003342 # Templated operator. Filled in by createDynamicOpLists
3343 "conv3d_TEMPLATE": {
3344 "op": Op.CONV3D,
3345 "operands": (1, 2),
3346 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003347 "build_fcn": (
3348 build_conv3d,
3349 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003350 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003351 TosaArgGen.agConv,
3352 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003353 "qgen": TosaQuantGen.qgConv,
3354 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003355 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3356 "error_if_validators": (
3357 TosaErrorValidator.evWrongInputType,
3358 TosaErrorValidator.evWrongOutputType,
3359 TosaErrorValidator.evWrongInputList,
3360 TosaErrorValidator.evWrongOutputList,
3361 TosaErrorValidator.evInputZeroPointNotZero,
3362 TosaErrorValidator.evWeightZeroPointNotZero,
3363 TosaErrorValidator.evPadSmallerZero,
3364 TosaErrorValidator.evStrideSmallerOne,
3365 TosaErrorValidator.evDilationSmallerOne,
3366 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003367 TosaErrorValidator.evConvOutputShapeMismatch,
3368 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003369 ),
evacha0147ab1762024-01-29 13:23:23 +00003370 "data_gen": {
3371 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3372 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003373 "template": True,
3374 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003375 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003376 "depthwise_conv2d_TEMPLATE": {
3377 "op": Op.DEPTHWISE_CONV2D,
3378 "operands": (1, 2),
3379 "filter": [1, 1],
3380 "rank": (4, 4),
3381 "build_fcn": (
3382 build_depthwise_conv2d,
3383 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003384 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003385 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003386 ),
3387 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003388 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003389 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3390 "error_if_validators": (
3391 TosaErrorValidator.evWrongInputType,
3392 TosaErrorValidator.evWrongOutputType,
3393 TosaErrorValidator.evWrongInputList,
3394 TosaErrorValidator.evWrongOutputList,
3395 TosaErrorValidator.evInputZeroPointNotZero,
3396 TosaErrorValidator.evWeightZeroPointNotZero,
3397 TosaErrorValidator.evPadSmallerZero,
3398 TosaErrorValidator.evStrideSmallerOne,
3399 TosaErrorValidator.evDilationSmallerOne,
3400 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003401 TosaErrorValidator.evConvOutputShapeMismatch,
3402 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003403 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003404 "data_gen": {
3405 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3406 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003407 "template": True,
3408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "fully_connected": {
3410 "op": Op.FULLY_CONNECTED,
3411 "operands": (1, 2),
3412 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003413 "build_fcn": (
3414 build_fully_connected,
3415 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003416 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003417 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003418 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003420 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003421 "error_if_validators": (
3422 TosaErrorValidator.evInputZeroPointNotZero,
3423 TosaErrorValidator.evWeightZeroPointNotZero,
3424 TosaErrorValidator.evWrongRank,
3425 TosaErrorValidator.evWrongInputType,
3426 TosaErrorValidator.evWrongOutputType,
3427 TosaErrorValidator.evWrongInputList,
3428 TosaErrorValidator.evWrongOutputList,
3429 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003430 "data_gen": {
3431 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3432 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 "matmul": {
3435 "op": Op.MATMUL,
3436 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003437 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003438 "build_fcn": (
3439 build_matmul,
3440 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003441 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003442 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003443 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003445 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003446 "error_if_validators": (
3447 TosaErrorValidator.evInputZeroPointNotZero,
3448 TosaErrorValidator.evWrongRank,
3449 TosaErrorValidator.evWrongInputType,
3450 TosaErrorValidator.evWrongOutputType,
3451 TosaErrorValidator.evWrongInputList,
3452 TosaErrorValidator.evWrongOutputList,
3453 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003454 "data_gen": {
3455 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003458 "max_pool2d": {
3459 "op": Op.MAX_POOL2D,
3460 "operands": (1, 0),
3461 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003463 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003464 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003465 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003466 TosaArgGen.agPooling,
3467 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003468 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003469 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003470 "error_if_validators": (
3471 TosaErrorValidator.evKernelSmallerOne,
3472 TosaErrorValidator.evStrideSmallerOne,
3473 TosaErrorValidator.evPadSmallerZero,
3474 TosaErrorValidator.evWrongRank,
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList,
3478 TosaErrorValidator.evWrongOutputList,
3479 TosaErrorValidator.evPadLargerEqualKernel,
3480 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003481 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003483 "data_gen": {
3484 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3485 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003487 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003488 "transpose_conv2d_TEMPLATE": {
3489 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003490 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003491 "rank": (4, 4),
3492 "build_fcn": (
3493 build_transpose_conv2d,
3494 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003495 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003496 TosaArgGen.agTransposeConv2D,
3497 ),
3498 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003499 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003500 "invalid_test_validators": (
3501 TosaInvalidValidator.ivHeightWidthInvalid,
3502 TosaInvalidValidator.ivNonPositiveOutputShape,
3503 ),
3504 "error_if_validators": (
3505 TosaErrorValidator.evWrongInputType,
3506 TosaErrorValidator.evWrongOutputType,
3507 TosaErrorValidator.evWrongInputList,
3508 TosaErrorValidator.evWrongOutputList,
3509 TosaErrorValidator.evInputZeroPointNotZero,
3510 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003511 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003512 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003513 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003514 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003515 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003516 "data_gen": {
3517 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3518 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 "template": True,
3520 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003521 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003522 "clamp": {
3523 "op": Op.CLAMP,
3524 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003525 "build_fcn": (
3526 build_clamp,
3527 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003528 TosaTensorValuesGen.tvgLazyGenDefault,
3529 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003531 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003532 "error_if_validators": (
3533 TosaErrorValidator.evMaxSmallerMin,
3534 TosaErrorValidator.evWrongInputType,
3535 TosaErrorValidator.evWrongOutputType,
3536 TosaErrorValidator.evWrongInputList,
3537 TosaErrorValidator.evWrongOutputList,
3538 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003539 "data_gen": {
3540 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3541 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003542 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003543 "sigmoid": {
3544 "op": Op.SIGMOID,
3545 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003547 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003548 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003549 TosaTensorValuesGen.tvgLazyGenDefault,
3550 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003552 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 "error_if_validators": (
3554 TosaErrorValidator.evWrongInputType,
3555 TosaErrorValidator.evWrongOutputType,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003559 "data_gen": {
3560 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3561 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003562 },
3563 "tanh": {
3564 "op": Op.TANH,
3565 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003566 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003567 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003569 TosaTensorValuesGen.tvgLazyGenDefault,
3570 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003571 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003572 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003573 "error_if_validators": (
3574 TosaErrorValidator.evWrongInputType,
3575 TosaErrorValidator.evWrongOutputType,
3576 TosaErrorValidator.evWrongInputList,
3577 TosaErrorValidator.evWrongOutputList,
3578 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003579 "data_gen": {
3580 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3581 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003582 "compliance": {
3583 "abs_error_lower_bound": 0.5,
3584 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003585 },
Won Jeon78155c62023-06-10 00:20:04 +00003586 "erf": {
3587 "op": Op.ERF,
3588 "operands": (1, 0),
3589 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003590 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003591 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003592 TosaTensorValuesGen.tvgLazyGenDefault,
3593 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003594 ),
3595 "types": TYPE_FP,
3596 "error_if_validators": (
3597 TosaErrorValidator.evWrongInputType,
3598 TosaErrorValidator.evWrongOutputType,
3599 TosaErrorValidator.evWrongInputList,
3600 TosaErrorValidator.evWrongOutputList,
3601 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003602 "data_gen": {
3603 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3604 },
3605 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003606 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003607 # Elementwise Binary Operators
3608 "add": {
3609 "op": Op.ADD,
3610 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003611 "build_fcn": (
3612 build_binary_broadcast,
3613 TosaTensorGen.tgBroadcastFuzz,
3614 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003615 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003616 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003618 "error_if_validators": (
3619 TosaErrorValidator.evRankMismatch,
3620 TosaErrorValidator.evWrongInputType,
3621 TosaErrorValidator.evWrongOutputType,
3622 TosaErrorValidator.evWrongInputList,
3623 TosaErrorValidator.evWrongOutputList,
3624 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003625 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003627 "data_gen": {
3628 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3629 },
3630 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003631 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003632 "arithmetic_right_shift": {
3633 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3634 "operands": (2, 0),
3635 "build_fcn": (
3636 build_arithmetic_right_shift,
3637 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003638 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003639 TosaArgGen.agArithmeticRightShift,
3640 ),
3641 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003642 "error_if_validators": (
3643 TosaErrorValidator.evRankMismatch,
3644 TosaErrorValidator.evWrongInputType,
3645 TosaErrorValidator.evWrongOutputType,
3646 TosaErrorValidator.evWrongInputList,
3647 TosaErrorValidator.evWrongOutputList,
3648 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003649 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003650 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003651 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003652 "bitwise_and": {
3653 "op": Op.BITWISE_AND,
3654 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 "build_fcn": (
3656 build_binary_broadcast,
3657 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003658 TosaTensorValuesGen.tvgLazyGenDefault,
3659 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 "error_if_validators": (
3663 TosaErrorValidator.evRankMismatch,
3664 TosaErrorValidator.evWrongInputType,
3665 TosaErrorValidator.evWrongOutputType,
3666 TosaErrorValidator.evWrongInputList,
3667 TosaErrorValidator.evWrongOutputList,
3668 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003669 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003670 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003671 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003672 "bitwise_or": {
3673 "op": Op.BITWISE_OR,
3674 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003675 "build_fcn": (
3676 build_binary_broadcast,
3677 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003678 TosaTensorValuesGen.tvgLazyGenDefault,
3679 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003680 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003681 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003682 "error_if_validators": (
3683 TosaErrorValidator.evRankMismatch,
3684 TosaErrorValidator.evWrongInputType,
3685 TosaErrorValidator.evWrongOutputType,
3686 TosaErrorValidator.evWrongInputList,
3687 TosaErrorValidator.evWrongOutputList,
3688 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003689 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003690 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003691 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003692 "bitwise_xor": {
3693 "op": Op.BITWISE_XOR,
3694 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003695 "build_fcn": (
3696 build_binary_broadcast,
3697 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003698 TosaTensorValuesGen.tvgLazyGenDefault,
3699 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003700 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003701 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003702 "error_if_validators": (
3703 TosaErrorValidator.evRankMismatch,
3704 TosaErrorValidator.evWrongInputType,
3705 TosaErrorValidator.evWrongOutputType,
3706 TosaErrorValidator.evWrongInputList,
3707 TosaErrorValidator.evWrongOutputList,
3708 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003709 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003710 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003711 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003712 "intdiv": {
3713 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003714 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003715 "build_fcn": (
3716 build_binary_broadcast,
3717 TosaTensorGen.tgBroadcastFuzz,
3718 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003719 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003720 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003721 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 "error_if_validators": (
3723 TosaErrorValidator.evRankMismatch,
3724 TosaErrorValidator.evWrongInputType,
3725 TosaErrorValidator.evWrongOutputType,
3726 TosaErrorValidator.evWrongInputList,
3727 TosaErrorValidator.evWrongOutputList,
3728 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003729 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003730 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003731 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003732 "logical_and": {
3733 "op": Op.LOGICAL_AND,
3734 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003735 "build_fcn": (
3736 build_binary_broadcast,
3737 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003738 TosaTensorValuesGen.tvgLazyGenDefault,
3739 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003740 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evRankMismatch,
3744 TosaErrorValidator.evWrongInputType,
3745 TosaErrorValidator.evWrongOutputType,
3746 TosaErrorValidator.evWrongInputList,
3747 TosaErrorValidator.evWrongOutputList,
3748 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003749 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003751 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 "logical_left_shift": {
3753 "op": Op.LOGICAL_LEFT_SHIFT,
3754 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003755 "build_fcn": (
3756 build_binary_broadcast,
3757 TosaTensorGen.tgBroadcastFuzz,
3758 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003759 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003760 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 "error_if_validators": (
3763 TosaErrorValidator.evRankMismatch,
3764 TosaErrorValidator.evWrongInputType,
3765 TosaErrorValidator.evWrongOutputType,
3766 TosaErrorValidator.evWrongInputList,
3767 TosaErrorValidator.evWrongOutputList,
3768 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003769 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003770 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003771 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 "logical_right_shift": {
3773 "op": Op.LOGICAL_RIGHT_SHIFT,
3774 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003775 "build_fcn": (
3776 build_binary_broadcast,
3777 TosaTensorGen.tgBroadcastFuzz,
3778 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003779 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003780 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003782 "error_if_validators": (
3783 TosaErrorValidator.evRankMismatch,
3784 TosaErrorValidator.evWrongInputType,
3785 TosaErrorValidator.evWrongOutputType,
3786 TosaErrorValidator.evWrongInputList,
3787 TosaErrorValidator.evWrongOutputList,
3788 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003789 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003790 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 "logical_or": {
3793 "op": Op.LOGICAL_OR,
3794 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003795 "build_fcn": (
3796 build_binary_broadcast,
3797 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003798 TosaTensorValuesGen.tvgLazyGenDefault,
3799 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003801 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003802 "error_if_validators": (
3803 TosaErrorValidator.evRankMismatch,
3804 TosaErrorValidator.evWrongInputType,
3805 TosaErrorValidator.evWrongOutputType,
3806 TosaErrorValidator.evWrongInputList,
3807 TosaErrorValidator.evWrongOutputList,
3808 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003809 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003811 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003812 "logical_xor": {
3813 "op": Op.LOGICAL_XOR,
3814 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003815 "build_fcn": (
3816 build_binary_broadcast,
3817 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003818 TosaTensorValuesGen.tvgLazyGenDefault,
3819 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003820 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003822 "error_if_validators": (
3823 TosaErrorValidator.evRankMismatch,
3824 TosaErrorValidator.evWrongInputType,
3825 TosaErrorValidator.evWrongOutputType,
3826 TosaErrorValidator.evWrongInputList,
3827 TosaErrorValidator.evWrongOutputList,
3828 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003829 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003830 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003831 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003832 "maximum": {
3833 "op": Op.MAXIMUM,
3834 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003835 "build_fcn": (
3836 build_binary_broadcast,
3837 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003838 TosaTensorValuesGen.tvgLazyGenDefault,
3839 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003842 "error_if_validators": (
3843 TosaErrorValidator.evRankMismatch,
3844 TosaErrorValidator.evWrongInputType,
3845 TosaErrorValidator.evWrongOutputType,
3846 TosaErrorValidator.evWrongInputList,
3847 TosaErrorValidator.evWrongOutputList,
3848 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003849 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003850 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003851 "data_gen": {
3852 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3853 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 "minimum": {
3856 "op": Op.MINIMUM,
3857 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003858 "build_fcn": (
3859 build_binary_broadcast,
3860 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003861 TosaTensorValuesGen.tvgLazyGenDefault,
3862 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003863 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003865 "error_if_validators": (
3866 TosaErrorValidator.evRankMismatch,
3867 TosaErrorValidator.evWrongInputType,
3868 TosaErrorValidator.evWrongOutputType,
3869 TosaErrorValidator.evWrongInputList,
3870 TosaErrorValidator.evWrongOutputList,
3871 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003872 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003873 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003874 "data_gen": {
3875 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3876 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 "mul": {
3879 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003880 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003881 "build_fcn": (
3882 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003883 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 TosaTensorValuesGen.tvgMul,
3885 TosaArgGen.agMul,
3886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003887 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003888 "error_if_validators": (
3889 TosaErrorValidator.evWrongInputType,
3890 TosaErrorValidator.evWrongOutputType,
3891 TosaErrorValidator.evWrongInputList,
3892 TosaErrorValidator.evWrongOutputList,
3893 TosaErrorValidator.evRankMismatch,
3894 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003895 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003896 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003897 "data_gen": {
3898 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3899 },
3900 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003901 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003902 "pow": {
3903 "op": Op.POW,
3904 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003905 "build_fcn": (
3906 build_binary_broadcast,
3907 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003908 TosaTensorValuesGen.tvgPow,
3909 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003911 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003912 "error_if_validators": (
3913 TosaErrorValidator.evRankMismatch,
3914 TosaErrorValidator.evWrongInputType,
3915 TosaErrorValidator.evWrongOutputType,
3916 TosaErrorValidator.evWrongInputList,
3917 TosaErrorValidator.evWrongOutputList,
3918 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003919 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003920 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003921 "data_gen": {
3922 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3923 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003924 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 "sub": {
3926 "op": Op.SUB,
3927 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003928 "build_fcn": (
3929 build_binary_broadcast,
3930 TosaTensorGen.tgBroadcastFuzz,
3931 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003932 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003933 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003934 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003935 "error_if_validators": (
3936 TosaErrorValidator.evRankMismatch,
3937 TosaErrorValidator.evWrongInputType,
3938 TosaErrorValidator.evWrongOutputType,
3939 TosaErrorValidator.evWrongInputList,
3940 TosaErrorValidator.evWrongOutputList,
3941 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003942 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003943 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003944 "data_gen": {
3945 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3946 },
3947 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003948 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003949 "table": {
3950 "op": Op.TABLE,
3951 # Use the automatic generation functions to create the input array
3952 # but create the table tensor in the build function, as it may be
3953 # a different type from the input
3954 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 "build_fcn": (
3956 build_table,
3957 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00003958 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003959 TosaArgGen.agTable,
3960 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003961 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 "error_if_validators": (
3963 TosaErrorValidator.evWrongInputType,
3964 TosaErrorValidator.evWrongOutputType,
3965 TosaErrorValidator.evWrongInputList,
3966 TosaErrorValidator.evWrongOutputList,
3967 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 # Elementwise Unary operators
3970 "abs": {
3971 "op": Op.ABS,
3972 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_unary,
3975 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003976 TosaTensorValuesGen.tvgLazyGenDefault,
3977 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 "error_if_validators": (
3981 TosaErrorValidator.evWrongInputType,
3982 TosaErrorValidator.evWrongOutputType,
3983 TosaErrorValidator.evWrongInputList,
3984 TosaErrorValidator.evWrongOutputList,
3985 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003986 "data_gen": {
3987 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3988 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 "bitwise_not": {
3991 "op": Op.BITWISE_NOT,
3992 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 "build_fcn": (
3994 build_unary,
3995 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00003996 TosaTensorValuesGen.tvgLazyGenDefault,
3997 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004000 "error_if_validators": (
4001 TosaErrorValidator.evWrongInputType,
4002 TosaErrorValidator.evWrongOutputType,
4003 TosaErrorValidator.evWrongInputList,
4004 TosaErrorValidator.evWrongOutputList,
4005 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004007 "ceil": {
4008 "op": Op.CEIL,
4009 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004010 "build_fcn": (
4011 build_unary,
4012 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004013 TosaTensorValuesGen.tvgLazyGenDefault,
4014 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004015 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004017 "error_if_validators": (
4018 TosaErrorValidator.evWrongInputType,
4019 TosaErrorValidator.evWrongOutputType,
4020 TosaErrorValidator.evWrongInputList,
4021 TosaErrorValidator.evWrongOutputList,
4022 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004023 "data_gen": {
4024 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4025 },
4026 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004027 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004028 "clz": {
4029 "op": Op.CLZ,
4030 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004031 "build_fcn": (
4032 build_unary,
4033 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004034 TosaTensorValuesGen.tvgLazyGenDefault,
4035 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004038 "error_if_validators": (
4039 TosaErrorValidator.evWrongInputType,
4040 TosaErrorValidator.evWrongOutputType,
4041 TosaErrorValidator.evWrongInputList,
4042 TosaErrorValidator.evWrongOutputList,
4043 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004044 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004045 "cos": {
4046 "op": Op.COS,
4047 "operands": (1, 0),
4048 "build_fcn": (
4049 build_unary,
4050 TosaTensorGen.tgBasic,
4051 TosaTensorValuesGen.tvgLazyGenDefault,
4052 TosaArgGen.agNone,
4053 ),
4054 "types": TYPE_FP,
4055 "error_if_validators": (
4056 TosaErrorValidator.evWrongInputType,
4057 TosaErrorValidator.evWrongOutputType,
4058 TosaErrorValidator.evWrongInputList,
4059 TosaErrorValidator.evWrongOutputList,
4060 ),
4061 "data_gen": {
4062 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4063 },
4064 "compliance": {"abs_error_normal_divisor": 2},
4065 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004066 "exp": {
4067 "op": Op.EXP,
4068 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 "build_fcn": (
4070 build_unary,
4071 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004072 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004073 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004074 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004075 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004076 "error_if_validators": (
4077 TosaErrorValidator.evWrongInputType,
4078 TosaErrorValidator.evWrongOutputType,
4079 TosaErrorValidator.evWrongInputList,
4080 TosaErrorValidator.evWrongOutputList,
4081 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004082 "data_gen": {
4083 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4084 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004085 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004086 "floor": {
4087 "op": Op.FLOOR,
4088 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004089 "build_fcn": (
4090 build_unary,
4091 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004092 TosaTensorValuesGen.tvgLazyGenDefault,
4093 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004096 "error_if_validators": (
4097 TosaErrorValidator.evWrongInputType,
4098 TosaErrorValidator.evWrongOutputType,
4099 TosaErrorValidator.evWrongInputList,
4100 TosaErrorValidator.evWrongOutputList,
4101 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004102 "data_gen": {
4103 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4104 },
4105 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004106 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 "log": {
4108 "op": Op.LOG,
4109 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004110 "build_fcn": (
4111 build_unary,
4112 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004113 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004114 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004115 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004116 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004117 "error_if_validators": (
4118 TosaErrorValidator.evWrongInputType,
4119 TosaErrorValidator.evWrongOutputType,
4120 TosaErrorValidator.evWrongInputList,
4121 TosaErrorValidator.evWrongOutputList,
4122 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004123 "data_gen": {
4124 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4125 },
4126 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004128 "logical_not": {
4129 "op": Op.LOGICAL_NOT,
4130 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 "build_fcn": (
4132 build_unary,
4133 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004134 TosaTensorValuesGen.tvgLazyGenDefault,
4135 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 "error_if_validators": (
4139 TosaErrorValidator.evWrongInputType,
4140 TosaErrorValidator.evWrongOutputType,
4141 TosaErrorValidator.evWrongInputList,
4142 TosaErrorValidator.evWrongOutputList,
4143 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004145 "negate": {
4146 "op": Op.NEGATE,
4147 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004148 "build_fcn": (
4149 build_unary,
4150 TosaTensorGen.tgBasic,
4151 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004152 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004153 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004154 "qgen": TosaQuantGen.qgUnary,
4155 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004156 "error_if_validators": (
4157 TosaErrorValidator.evInputZeroPointNotZero,
4158 TosaErrorValidator.evOutputZeroPointNotZero,
4159 TosaErrorValidator.evWrongInputType,
4160 TosaErrorValidator.evWrongOutputType,
4161 TosaErrorValidator.evWrongInputList,
4162 TosaErrorValidator.evWrongOutputList,
4163 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004164 "data_gen": {
4165 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4166 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004167 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004168 "reciprocal": {
4169 "op": Op.RECIPROCAL,
4170 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004171 "build_fcn": (
4172 build_unary,
4173 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004174 TosaTensorValuesGen.tvgLazyGenDefault,
4175 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004176 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004177 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004178 "error_if_validators": (
4179 TosaErrorValidator.evWrongInputType,
4180 TosaErrorValidator.evWrongOutputType,
4181 TosaErrorValidator.evWrongInputList,
4182 TosaErrorValidator.evWrongOutputList,
4183 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004184 "data_gen": {
4185 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4186 },
4187 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 "rsqrt": {
4190 "op": Op.RSQRT,
4191 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004192 "build_fcn": (
4193 build_unary,
4194 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004195 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004196 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004197 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004199 "error_if_validators": (
4200 TosaErrorValidator.evWrongInputType,
4201 TosaErrorValidator.evWrongOutputType,
4202 TosaErrorValidator.evWrongInputList,
4203 TosaErrorValidator.evWrongOutputList,
4204 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004205 "data_gen": {
4206 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4207 },
4208 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004210 "sin": {
4211 "op": Op.SIN,
4212 "operands": (1, 0),
4213 "build_fcn": (
4214 build_unary,
4215 TosaTensorGen.tgBasic,
4216 TosaTensorValuesGen.tvgLazyGenDefault,
4217 TosaArgGen.agNone,
4218 ),
4219 "types": TYPE_FP,
4220 "error_if_validators": (
4221 TosaErrorValidator.evWrongInputType,
4222 TosaErrorValidator.evWrongOutputType,
4223 TosaErrorValidator.evWrongInputList,
4224 TosaErrorValidator.evWrongOutputList,
4225 ),
4226 "data_gen": {
4227 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4228 },
4229 "compliance": {"abs_error_normal_divisor": 2},
4230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 # Elementwise Ternary operators
4232 "select": {
4233 "op": Op.SELECT,
4234 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004235 "build_fcn": (
4236 build_select,
4237 TosaTensorGen.tgBroadcastFuzz,
4238 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004239 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004240 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004241 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004242 "error_if_validators": (
4243 TosaErrorValidator.evRankMismatch,
4244 TosaErrorValidator.evWrongInputType,
4245 TosaErrorValidator.evWrongOutputType,
4246 TosaErrorValidator.evWrongInputList,
4247 TosaErrorValidator.evWrongOutputList,
4248 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004249 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004250 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004251 "data_gen": {
4252 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4253 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004255 # Comparison operators
4256 "equal": {
4257 "op": Op.EQUAL,
4258 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004259 "build_fcn": (
4260 build_comparison,
4261 TosaTensorGen.tgBroadcastFuzz,
4262 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004263 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004265 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004266 "error_if_validators": (
4267 TosaErrorValidator.evRankMismatch,
4268 TosaErrorValidator.evWrongInputType,
4269 TosaErrorValidator.evWrongOutputType,
4270 TosaErrorValidator.evWrongInputList,
4271 TosaErrorValidator.evWrongOutputList,
4272 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004273 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004274 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004275 "data_gen": {
4276 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004279 "greater_equal": {
4280 "op": Op.GREATER_EQUAL,
4281 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004282 "build_fcn": (
4283 build_comparison,
4284 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004285 TosaTensorValuesGen.tvgLazyGenDefault,
4286 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004287 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004288 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004289 "error_if_validators": (
4290 TosaErrorValidator.evRankMismatch,
4291 TosaErrorValidator.evWrongInputType,
4292 TosaErrorValidator.evWrongOutputType,
4293 TosaErrorValidator.evWrongInputList,
4294 TosaErrorValidator.evWrongOutputList,
4295 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004296 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004297 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004298 "data_gen": {
4299 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004302 "greater": {
4303 "op": Op.GREATER,
4304 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004305 "build_fcn": (
4306 build_comparison,
4307 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004308 TosaTensorValuesGen.tvgLazyGenDefault,
4309 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004310 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004311 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 "error_if_validators": (
4313 TosaErrorValidator.evRankMismatch,
4314 TosaErrorValidator.evWrongInputType,
4315 TosaErrorValidator.evWrongOutputType,
4316 TosaErrorValidator.evWrongInputList,
4317 TosaErrorValidator.evWrongOutputList,
4318 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004319 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004320 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004321 "data_gen": {
4322 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004324 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004325 # Reduction operators
4326 "reduce_all": {
4327 "op": Op.REDUCE_ALL,
4328 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004329 "build_fcn": (
4330 build_reduce,
4331 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004332 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004333 TosaArgGen.agAxis,
4334 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004335 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004336 "error_if_validators": (
4337 TosaErrorValidator.evAxisLargerRank,
4338 TosaErrorValidator.evAxisSmallerZero,
4339 TosaErrorValidator.evShapeOfAxisNotOne,
4340 TosaErrorValidator.evWrongInputType,
4341 TosaErrorValidator.evWrongOutputType,
4342 TosaErrorValidator.evWrongRank,
4343 TosaErrorValidator.evWrongInputList,
4344 TosaErrorValidator.evWrongOutputList,
4345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004347 "reduce_any": {
4348 "op": Op.REDUCE_ANY,
4349 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004350 "build_fcn": (
4351 build_reduce,
4352 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004353 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004354 TosaArgGen.agAxis,
4355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004356 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004357 "error_if_validators": (
4358 TosaErrorValidator.evAxisLargerRank,
4359 TosaErrorValidator.evAxisSmallerZero,
4360 TosaErrorValidator.evShapeOfAxisNotOne,
4361 TosaErrorValidator.evWrongInputType,
4362 TosaErrorValidator.evWrongOutputType,
4363 TosaErrorValidator.evWrongRank,
4364 TosaErrorValidator.evWrongInputList,
4365 TosaErrorValidator.evWrongOutputList,
4366 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004368 "reduce_max": {
4369 "op": Op.REDUCE_MAX,
4370 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004371 "build_fcn": (
4372 build_reduce,
4373 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004374 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004375 TosaArgGen.agAxis,
4376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 "error_if_validators": (
4379 TosaErrorValidator.evAxisLargerRank,
4380 TosaErrorValidator.evAxisSmallerZero,
4381 TosaErrorValidator.evShapeOfAxisNotOne,
4382 TosaErrorValidator.evWrongInputType,
4383 TosaErrorValidator.evWrongOutputType,
4384 TosaErrorValidator.evWrongRank,
4385 TosaErrorValidator.evWrongInputList,
4386 TosaErrorValidator.evWrongOutputList,
4387 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004388 "data_gen": {
4389 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004392 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004393 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004394 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004395 "build_fcn": (
4396 build_reduce,
4397 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004398 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004399 TosaArgGen.agAxis,
4400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004401 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004402 "error_if_validators": (
4403 TosaErrorValidator.evAxisLargerRank,
4404 TosaErrorValidator.evAxisSmallerZero,
4405 TosaErrorValidator.evShapeOfAxisNotOne,
4406 TosaErrorValidator.evWrongInputType,
4407 TosaErrorValidator.evWrongOutputType,
4408 TosaErrorValidator.evWrongRank,
4409 TosaErrorValidator.evWrongInputList,
4410 TosaErrorValidator.evWrongOutputList,
4411 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004412 "data_gen": {
4413 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4414 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004415 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004416 "reduce_product": {
4417 "op": Op.REDUCE_PRODUCT,
4418 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004419 "build_fcn": (
4420 build_reduce,
4421 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004422 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004423 TosaArgGen.agAxis,
4424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004425 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004426 "error_if_validators": (
4427 TosaErrorValidator.evAxisLargerRank,
4428 TosaErrorValidator.evAxisSmallerZero,
4429 TosaErrorValidator.evShapeOfAxisNotOne,
4430 TosaErrorValidator.evWrongInputType,
4431 TosaErrorValidator.evWrongOutputType,
4432 TosaErrorValidator.evWrongRank,
4433 TosaErrorValidator.evWrongInputList,
4434 TosaErrorValidator.evWrongOutputList,
4435 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004436 "data_gen": {
4437 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4438 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004439 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004440 "reduce_sum": {
4441 "op": Op.REDUCE_SUM,
4442 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004443 "build_fcn": (
4444 build_reduce,
4445 TosaTensorGen.tgBasic,
4446 TosaTensorValuesGen.tvgReduceSum,
4447 TosaArgGen.agAxis,
4448 ),
James Ward24dbc422022-10-19 12:20:31 +01004449 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004450 "error_if_validators": (
4451 TosaErrorValidator.evAxisLargerRank,
4452 TosaErrorValidator.evAxisSmallerZero,
4453 TosaErrorValidator.evShapeOfAxisNotOne,
4454 TosaErrorValidator.evWrongInputType,
4455 TosaErrorValidator.evWrongOutputType,
4456 TosaErrorValidator.evWrongRank,
4457 TosaErrorValidator.evWrongInputList,
4458 TosaErrorValidator.evWrongOutputList,
4459 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004460 "data_gen": {
4461 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004463 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004464 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004465 "concat": {
4466 "op": Op.CONCAT,
4467 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004468 "build_fcn": (
4469 build_concat,
4470 TosaTensorGen.tgConcat,
4471 TosaTensorValuesGen.tvgConcat,
4472 TosaArgGen.agAxis,
4473 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004474 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004475 "error_if_validators": (
4476 TosaErrorValidator.evAxisLargerRank,
4477 TosaErrorValidator.evAxisSmallerZero,
4478 TosaErrorValidator.evConcatInputRankMismatch,
4479 TosaErrorValidator.evConcatShapeSumMismatch,
4480 TosaErrorValidator.evConcatInputDimMismatch,
4481 TosaErrorValidator.evWrongInputType,
4482 TosaErrorValidator.evWrongOutputType,
4483 TosaErrorValidator.evWrongOutputList,
4484 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004485 "data_gen": {
4486 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4487 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004488 },
4489 "pad": {
4490 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004491 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004492 "build_fcn": (
4493 build_pad,
4494 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004495 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004496 TosaArgGen.agPad,
4497 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004498 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004499 "error_if_validators": (
4500 TosaErrorValidator.evWrongInputType,
4501 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004502 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004503 TosaErrorValidator.evWrongOutputType,
4504 TosaErrorValidator.evWrongInputList,
4505 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004506 TosaErrorValidator.evRankMismatch,
4507 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004508 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004509 "data_gen": {
4510 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4511 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004512 },
Won Jeona21b2e82023-08-10 10:33:01 +00004513 "dim": {
4514 "op": Op.DIM,
4515 "operands": (1, 0),
4516 "build_fcn": (
4517 build_dim,
4518 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004519 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004520 TosaArgGen.agAxis,
4521 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004522 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004523 "error_if_validators": (
4524 TosaErrorValidator.evAxisLargerRank,
4525 TosaErrorValidator.evAxisSmallerZero,
4526 TosaErrorValidator.evWrongInputType,
4527 TosaErrorValidator.evWrongInputList,
4528 TosaErrorValidator.evWrongOutputList,
4529 TosaErrorValidator.evWrongRank,
4530 ),
4531 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004532 "reshape": {
4533 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004534 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004535 "build_fcn": (
4536 build_reshape,
4537 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004538 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004539 TosaArgGen.agReshape,
4540 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004541 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004542 "error_if_validators": (
4543 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4544 TosaErrorValidator.evWrongInputType,
4545 TosaErrorValidator.evWrongOutputType,
4546 TosaErrorValidator.evWrongInputList,
4547 TosaErrorValidator.evWrongOutputList,
4548 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004549 "data_gen": {
4550 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4551 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004552 },
4553 "reverse": {
4554 "op": Op.REVERSE,
4555 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004556 "build_fcn": (
4557 build_reverse,
4558 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004559 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004560 TosaArgGen.agAxis,
4561 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004562 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004563 "error_if_validators": (
4564 TosaErrorValidator.evAxisSmallerZero,
4565 TosaErrorValidator.evAxisLargerRank,
4566 TosaErrorValidator.evWrongInputType,
4567 TosaErrorValidator.evWrongOutputType,
4568 TosaErrorValidator.evWrongInputList,
4569 TosaErrorValidator.evWrongOutputList,
4570 ),
evacha0198477222024-01-26 12:25:32 +00004571 "data_gen": {
4572 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4573 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004574 },
4575 "slice": {
4576 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004577 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004578 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004579 "build_fcn": (
4580 build_slice,
4581 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004582 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004583 TosaArgGen.agSlice,
4584 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004585 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004586 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004587 # TODO Turn off these error categories for now as the reference
4588 # model cannot allocate memory space for empty tensor. We probably
4589 # can report an accurate error message at the right place during
4590 # execution.
4591 # TosaErrorValidator.evStartSmallerZero,
4592 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004593 TosaErrorValidator.evStartSizeOutsideBounds,
4594 TosaErrorValidator.evSizeOutputShapeMismatch,
4595 TosaErrorValidator.evInputSizeStartLengthMismatch,
4596 TosaErrorValidator.evWrongRank,
4597 TosaErrorValidator.evWrongInputType,
4598 TosaErrorValidator.evWrongOutputType,
4599 TosaErrorValidator.evWrongInputList,
4600 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004601 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004602 ),
evacha017f7d4252024-01-24 12:08:09 +00004603 "data_gen": {
4604 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4605 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004606 },
4607 "tile": {
4608 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004609 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004610 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004611 "build_fcn": (
4612 build_tile,
4613 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004614 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004615 TosaArgGen.agTile,
4616 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004617 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004618 "error_if_validators": (
4619 TosaErrorValidator.evWrongInputType,
4620 TosaErrorValidator.evWrongOutputType,
4621 TosaErrorValidator.evWrongInputList,
4622 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004623 TosaErrorValidator.evRankMismatch,
4624 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004625 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004626 "data_gen": {
4627 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4628 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004629 },
4630 "transpose": {
4631 "op": Op.TRANSPOSE,
4632 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004633 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004634 "build_fcn": (
4635 build_transpose,
4636 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004637 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004638 TosaArgGen.agTranspose,
4639 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004640 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004641 "error_if_validators": (
4642 TosaErrorValidator.evIndexOutsideBounds,
4643 TosaErrorValidator.evIndexUsedTwice,
4644 TosaErrorValidator.evWrongInputType,
4645 TosaErrorValidator.evWrongOutputType,
4646 TosaErrorValidator.evWrongInputList,
4647 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004648 TosaErrorValidator.evWrongRank,
4649 TosaErrorValidator.evRankMismatch,
4650 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004651 ),
evacha0198477222024-01-26 12:25:32 +00004652 "data_gen": {
4653 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4654 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004656 # Data nodes
4657 "const": {
4658 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004659 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004660 "build_fcn": (
4661 build_const,
4662 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004663 TosaTensorValuesGen.tvgLazyGenDefault,
4664 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004665 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004666 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004667 "data_gen": {
4668 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004670 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004671 "identity": {
4672 "op": Op.IDENTITY,
4673 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004674 "build_fcn": (
4675 build_unary,
4676 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004677 TosaTensorValuesGen.tvgLazyGenDefault,
4678 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004680 "types": TYPE_FIB,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004681 "data_gen": {
4682 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4683 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004684 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004685 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004686 "gather": {
4687 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004688 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004689 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004690 "build_fcn": (
4691 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004692 TosaTensorGen.tgGather,
4693 TosaTensorValuesGen.tvgGather,
4694 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004695 ),
James Ward24dbc422022-10-19 12:20:31 +01004696 "types": (
4697 DType.INT8,
4698 DType.INT16,
4699 DType.INT32,
4700 DType.FP16,
4701 DType.BF16,
4702 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004703 DType.FP8E4M3,
4704 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004705 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004706 "error_if_validators": (
4707 TosaErrorValidator.evWrongInputType,
4708 TosaErrorValidator.evWrongOutputType,
4709 TosaErrorValidator.evWrongInputList,
4710 TosaErrorValidator.evWrongOutputList,
4711 TosaErrorValidator.evWrongRank,
4712 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004713 "data_gen": {
4714 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4715 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004716 },
4717 "scatter": {
4718 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004719 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004720 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004721 "build_fcn": (
4722 build_scatter,
4723 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004724 TosaTensorValuesGen.tvgScatter,
4725 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004726 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004727 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004728 "error_if_validators": (
4729 TosaErrorValidator.evWrongInputType,
4730 TosaErrorValidator.evWrongOutputType,
4731 TosaErrorValidator.evWrongInputList,
4732 TosaErrorValidator.evWrongOutputList,
4733 TosaErrorValidator.evWrongRank,
4734 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004735 "data_gen": {
4736 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4737 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004738 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004739 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004740 "resize": {
4741 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004742 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004743 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004744 "build_fcn": (
4745 build_resize,
4746 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004747 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004748 TosaArgGen.agResize,
4749 ),
James Ward24dbc422022-10-19 12:20:31 +01004750 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004751 "invalid_test_validators": (
4752 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004753 ),
4754 "error_if_validators": (
4755 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004756 TosaErrorValidator.evScaleSmallerEqualZero,
4757 TosaErrorValidator.evScaleNLargerMax,
4758 TosaErrorValidator.evScaleDLargerMax,
4759 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004760 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004761 TosaErrorValidator.evBorderSmallerMin,
4762 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004763 TosaErrorValidator.evWrongInputType,
4764 TosaErrorValidator.evWrongOutputType,
4765 TosaErrorValidator.evWrongRank,
4766 TosaErrorValidator.evWrongInputList,
4767 TosaErrorValidator.evWrongOutputList,
4768 TosaErrorValidator.evBatchMismatch,
4769 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004770 TosaErrorValidator.evResizeOutputShapeMismatch,
4771 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004772 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004773 "data_gen": {
4774 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4775 },
4776 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004777 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004778 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004779 "cast": {
4780 "op": Op.CAST,
4781 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004782 "build_fcn": (
4783 build_cast,
4784 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004785 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004786 TosaArgGen.agCast,
4787 ),
James Ward8b390432022-08-12 20:48:56 +01004788 "types": (
4789 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004790 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004791 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004792 DType.INT8,
4793 DType.INT16,
4794 DType.INT32,
4795 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004796 DType.FP8E4M3,
4797 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004798 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004799 "error_if_validators": (
4800 TosaErrorValidator.evWrongInputType,
4801 TosaErrorValidator.evWrongOutputType,
4802 TosaErrorValidator.evWrongInputList,
4803 TosaErrorValidator.evWrongOutputList,
4804 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004805 "data_gen": {
4806 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4807 },
4808 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004809 },
4810 "rescale": {
4811 "op": Op.RESCALE,
4812 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004813 "build_fcn": (
4814 build_rescale,
4815 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004816 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004817 TosaArgGen.agRescale,
4818 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004819 "types": [
4820 DType.UINT8,
4821 DType.INT8,
4822 DType.INT16,
4823 DType.INT32,
4824 DType.INT48,
4825 DType.UINT16,
4826 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004827 "error_if_validators": (
4828 TosaErrorValidator.evInputZeroPointNotZero,
4829 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004830 TosaErrorValidator.evU16InputZeroPointNotValid,
4831 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004832 TosaErrorValidator.evScaleTrue,
4833 TosaErrorValidator.evScaleNotTrue,
4834 TosaErrorValidator.evWrongInputType,
4835 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004836 TosaErrorValidator.evWrongInputList,
4837 TosaErrorValidator.evWrongOutputList,
4838 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004839 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004840 # Custom
4841 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004842 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004843 # Two variants of cond_if, one that generates one of two constant tensors (no
4844 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4845 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004846 "cond_if_const": {
4847 "op": Op.COND_IF,
4848 "operands": (0, 2),
4849 "build_fcn": (
4850 build_cond_if_const,
4851 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004852 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004853 TosaArgGen.agCondIf,
4854 ),
4855 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004856 "error_if_validators": (
4857 TosaErrorValidator.evOutputListThenGraphMismatch,
4858 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004859 TosaErrorValidator.evCondIfCondNotMatchingBool,
4860 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004861 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004862 },
4863 "cond_if_binary": {
4864 "op": Op.COND_IF,
4865 "operands": (2, 0),
4866 "build_fcn": (
4867 build_cond_if_binary,
4868 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004869 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004870 TosaArgGen.agCondIf,
4871 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004872 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004873 "error_if_validators": (
4874 TosaErrorValidator.evInputListThenGraphMismatch,
4875 TosaErrorValidator.evInputListElseGraphMismatch,
4876 TosaErrorValidator.evOutputListThenGraphMismatch,
4877 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004878 TosaErrorValidator.evCondIfCondNotMatchingBool,
4879 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004880 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004881 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004882 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004883 "while_loop": {
4884 "op": Op.WHILE_LOOP,
4885 "operands": (0, 1),
4886 "build_fcn": (
4887 build_while_loop,
4888 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004889 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004890 TosaArgGen.agWhileLoop,
4891 ),
4892 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004893 "error_if_validators": (
4894 TosaErrorValidator.evInputListOutputListMismatch,
4895 TosaErrorValidator.evInputListCondGraphMismatch,
4896 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4897 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4898 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004899 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004900 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004901 },
Luke Hutton57287132023-02-06 14:54:18 +00004902 "fft2d": {
4903 "op": Op.FFT2D,
4904 "operands": (2, 0),
4905 "rank": (3, 3),
4906 "build_fcn": (
4907 build_fft2d,
4908 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004909 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004910 TosaArgGen.agFFT2d,
4911 ),
4912 "types": [DType.FP32],
4913 "error_if_validators": (
4914 TosaErrorValidator.evWrongInputType,
4915 TosaErrorValidator.evWrongOutputType,
4916 TosaErrorValidator.evWrongInputList,
4917 TosaErrorValidator.evWrongOutputList,
4918 TosaErrorValidator.evWrongRank,
4919 TosaErrorValidator.evBatchMismatch,
4920 TosaErrorValidator.evKernelNotPowerOfTwo,
4921 TosaErrorValidator.evFFTInputShapeMismatch,
4922 TosaErrorValidator.evFFTOutputShapeMismatch,
4923 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004924 "data_gen": {
4925 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4926 },
Luke Hutton57287132023-02-06 14:54:18 +00004927 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004928 "rfft2d": {
4929 "op": Op.RFFT2D,
4930 "operands": (1, 0),
4931 "rank": (3, 3),
4932 "build_fcn": (
4933 build_rfft2d,
4934 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004935 TosaTensorValuesGen.tvgLazyGenDefault,
4936 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004937 ),
4938 "types": [DType.FP32],
4939 "error_if_validators": (
4940 TosaErrorValidator.evWrongInputType,
4941 TosaErrorValidator.evWrongOutputType,
4942 TosaErrorValidator.evWrongInputList,
4943 TosaErrorValidator.evWrongOutputList,
4944 TosaErrorValidator.evWrongRank,
4945 TosaErrorValidator.evBatchMismatch,
4946 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004947 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004948 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004949 "data_gen": {
4950 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4951 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004952 },
Won Jeon74342e52024-01-09 00:34:40 +00004953 # Shape
4954 "add_shape": {
4955 "op": Op.ADD_SHAPE,
4956 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004957 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004958 "build_fcn": (
4959 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004960 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004961 TosaTensorValuesGen.tvgAddSub,
4962 TosaArgGen.agNone,
4963 ),
4964 "types": [DType.SHAPE],
4965 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4966 },
4967 "sub_shape": {
4968 "op": Op.SUB_SHAPE,
4969 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004970 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004971 "build_fcn": (
4972 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004973 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004974 TosaTensorValuesGen.tvgAddSub,
4975 TosaArgGen.agNone,
4976 ),
4977 "types": [DType.SHAPE],
4978 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4979 },
4980 "mul_shape": {
4981 "op": Op.MUL_SHAPE,
4982 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004983 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004984 "build_fcn": (
4985 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004986 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004987 TosaTensorValuesGen.tvgMul,
4988 TosaArgGen.agNone,
4989 ),
4990 "types": [DType.SHAPE],
4991 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4992 },
4993 "div_shape": {
4994 "op": Op.DIV_SHAPE,
4995 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004996 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004997 "build_fcn": (
4998 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004999 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005000 TosaTensorValuesGen.tvgIntDiv,
5001 TosaArgGen.agNone,
5002 ),
5003 "types": [DType.SHAPE],
5004 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5005 },
5006 "concat_shape": {
5007 "op": Op.CONCAT_SHAPE,
5008 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005009 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005010 "build_fcn": (
5011 build_concat,
5012 TosaTensorGen.tgConcat,
5013 TosaTensorValuesGen.tvgConcat,
5014 TosaArgGen.agNone,
5015 ),
5016 "types": [DType.SHAPE],
5017 "error_if_validators": (),
5018 },
5019 "const_shape": {
5020 "op": Op.CONST_SHAPE,
5021 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005022 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005023 "build_fcn": (
5024 build_const,
5025 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005026 TosaTensorValuesGen.tvgLazyGenDefault,
5027 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005028 ),
5029 "types": [DType.SHAPE],
5030 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005031 }
5032
Kevin Cheng550ccc52021-03-03 11:21:43 -08005033
Eric Kunzee5e26762020-10-13 16:11:07 -07005034class OutputShaper:
5035 # Methods in this class compute the expected output shape and datatype
5036 # for common classes of operations
def __init__(self):
    """Create a shaper; no per-instance state is set up."""
5039
5040 # These methods return arguments that can be used for
5041 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an element-wise binary op with broadcasting.

    When no ERROR_IF case is requested, a dimension where ``a`` has extent 1
    takes ``b``'s extent; every other dimension keeps ``a``'s extent.
    ErrorIf.DimensionMismatch grows one randomly chosen dimension by 1, and
    ErrorIf.WrongOutputType substitutes a dtype different from ``a.dtype``.
    The rank-equality check is skipped for ErrorIf.RankMismatch.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # b is only indexed when broadcasting actually applies.
    shape = [
        b.shape[idx] if (extent == 1 and error_name is None) else extent
        for idx, extent in enumerate(a.shape)
    ]

    # Drawn unconditionally, so the amount of RNG consumed here does not
    # depend on error_name.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005075
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose operands must match exactly.

    Asserts that ``a`` and ``b`` share rank, dtype and every dimension
    extent. The output gets a fresh list copy of ``a``'s shape and ``a``'s
    dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, extent in enumerate(a.shape):
        assert extent == b.shape[idx]
        out_shape.append(extent)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005087
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an element-wise unary op.

    The output keeps ``a``'s shape. Its dtype is ``a.dtype`` unless
    ErrorIf.WrongOutputType is requested; then a different dtype is drawn
    with ``rng.choice``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005106
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select over (cond, a, b) with broadcasting.

    When no ERROR_IF case is requested, a dimension where ``cond`` has
    extent 1 takes the largest extent among cond, a and b. Every other
    dimension keeps ``cond``'s extent. ErrorIf.DimensionMismatch grows one
    random dimension by 1, and ErrorIf.WrongOutputType substitutes a dtype
    different from ``a.dtype``. Rank checks are skipped for
    ErrorIf.RankMismatch.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    shape = [
        max(extent, a.shape[idx], b.shape[idx])
        if (extent == 1 and broadcasting)
        else extent
        for idx, extent in enumerate(cond.shape)
    ]

    # Drawn unconditionally (note: bounded by a's rank, not cond's).
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005140
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a binary comparison op.

    Broadcasting: a dimension where ``a`` has extent 1 takes ``b``'s extent
    when ``b`` has that dimension; otherwise ``a``'s extent is kept.
    ErrorIf.DimensionMismatch grows one random dimension by 1. The output
    dtype is DType.BOOL, except for ErrorIf.WrongOutputType, where a
    non-BOOL dtype is drawn instead.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = [
        b.shape[idx] if (extent == 1 and len(b.shape) > idx) else extent
        for idx, extent in enumerate(a.shape)
    ]

    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.BOOL)

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005174
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction along ``axis``.

    Normally the reduced axis becomes extent 1. For the axis ERROR_IF cases
    (AxisSmallerZero, AxisLargerRank, ShapeOfAxisNotOne) it is left as is.
    For ShapeOfAxisNotOne, an extent that is already 1 is replaced by a
    random value in [2, 10). ErrorIf.WrongOutputType substitutes a dtype
    different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    shape = a.shape.copy()

    axis_error_cases = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in axis_error_cases:
        shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
        shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005203
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of ARGMAX along ``axis``.

    The reduced axis is removed from ``a``'s shape, except for the
    AxisSmallerZero / AxisLargerRank cases. ArgmaxOutputRankMismatch
    either drops the leading dimension (if more than one remains) or
    appends a 1. ArgmaxOutputShapeMismatch adds a random amount in [1, 10)
    to every dimension. The output dtype is DType.INT32, except for
    ErrorIf.WrongOutputType, where a non-INT32 dtype is drawn instead.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(shape) > 1:
            del shape[0]
        else:
            shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # One rng draw per dimension, in order from the first dimension.
        shape = [extent + rng.integers(1, 10) for extent in shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.INT32)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {DType.INT32})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a CONV2D operator.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial extent is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the output.
        rng: random generator, used only for the ERROR_IF variations below.
        ifm: input tensor; ``shape`` is NHWC and ``dtype`` selects the
            WrongOutputType exclusions.
        filter: weight tensor (OHWI); ``shape[0]`` gives the output channels.
        accum_dtype: accumulator dtype, used as the normal output dtype.
        strides: (stride_h, stride_w).
        padding: padding[0:2] applies to H and padding[2:4] applies to W.
        dilations: (dilation_h, dilation_w).
        error_name: optional ErrorIf case to generate.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    h = (
        ifm.shape[1]
        - 1
        + padding[0]
        + padding[1]
        - (filter.shape[1] - 1) * dilations[0]
    ) // strides[0] + 1

    w = (
        ifm.shape[2]
        - 1
        + padding[2]
        + padding[3]
        - (filter.shape[2] - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of stride to not hit non-integer error case
        if change in [1, 3]:
            h = h + (rng.choice(choices) * strides[0])
        if change in [2, 3]:
            w = w + (rng.choice(choices) * strides[1])

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # These must be mutually exclusive branches. With a plain `if` for the
        # FP8 case, the FP16 exclusions were always overwritten by `else`.
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are excluded for FP16 input (presumably both
            # are acceptable output types there), so neither is picked as
            # the "wrong" one.
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005293
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a CONV3D operator.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. For spatial axis k
    (0=D, 1=H, 2=W) the extent is
    ``(ifm[k+1] - 1 + pad[2k] + pad[2k+1] - (filter[k+1] - 1) * dil[k]) // stride[k] + 1``.

    ErrorIf.ConvOutputShapeMismatch grows D, H or W (or all three) by a
    random multiple of the matching stride. ErrorIf.WrongInputType forces
    an INT32 output. ErrorIf.WrongOutputType draws a dtype from
    ``gtu.usableDTypes`` with the expected dtype(s) excluded.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """

    def _extent(k):
        # Operand order deliberately matches the closed form above.
        return (
            ifm.shape[k + 1]
            - 1
            + padding[2 * k]
            + padding[2 * k + 1]
            - (filter.shape[k + 1] - 1) * dilations[k]
        ) // strides[k] + 1

    spatial = [_extent(k) for k in range(3)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # change 1/2/3 grows only D/H/W respectively; 4 grows all three.
        # Growing by a multiple of the stride avoids the non-integer case.
        for k in range(3):
            if change in (k + 1, 4):
                spatial[k] += rng.choice(choices) * strides[k]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = (
        DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    )

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5355
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a depthwise 2D convolution.

    Layouts: input NHWC, filter HWCM, output N x OH x OW x (C * M).

    Returns whatever ``ser.addOutput`` returns. When ``error_name`` selects an
    ERROR_IF test, the output shape or dtype is deliberately made invalid.
    """
    # Dilated kernel extent minus one, per spatial axis.
    span_h = (filter.shape[0] - 1) * dilations[0]
    span_w = (filter.shape[1] - 1) * dilations[1]

    out_h = (ifm.shape[1] - 1 + padding[0] + padding[1] - span_h) // strides[0] + 1
    out_w = (ifm.shape[2] - 1 + padding[2] + padding[3] - span_w) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        which = rng.choice(options)
        # Grow by whole strides so the non-integer output ERROR_IF is not hit.
        if which in (1, 3):
            out_h += rng.choice(options) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(options) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is the error under test; use a plausible output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both are valid).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005406
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator (NHWC in and out).

    Invalid stride/padding yields a 1x1 spatial output, since such tests are
    invalid anyway. ERROR_IF tests may perturb the spatial dims or the dtype.
    """
    stride_y, stride_x = stride[0], stride[1]
    if stride_y <= 0 or stride_x <= 0 or min(pad) < 0:
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride_y + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride_x + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        options = [1, 2, 3]
        which = rng.choice(options)
        # Grow by whole strides so the non-integer output ERROR_IF is not hit.
        if which in (1, 3):
            out_h += rng.choice(options) * stride_y
        if which in (2, 3):
            out_w += rng.choice(options) * stride_x
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005446
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor for FULLY_CONNECTED.

    input is (N, IC), filter is (OC, IC); the result is (N, OC) with dtype
    ``accum_dtype``. That dtype is validated (or deliberately invalidated for
    ERROR_IF tests) by the argument generator, so it is used unchanged.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005459
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor for MATMUL.

    a is (N, H, C) and b is (N, C, W); the result is (N, H, W).

    Output dtype selection:
      * ErrorIf.WrongOutputType: a random pick from a per-input-dtype tuple
        of incorrect output types.
      * ErrorIf.WrongInputType: INT32, a potentially correct output dtype,
        since the input dtype is already the error under test.
      * otherwise: ``accum_dtype`` (validated by the argument generator).

    Raises:
        ValueError: for ErrorIf.WrongOutputType when ``a.dtype`` has no
            tuple of incorrect output types defined below.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # All listed types except INT32.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            # All listed types except INT48.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
            # All listed types except FP16.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            # FP32, FP16 and BF16 are all excluded.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        else:
            # Without this branch incorrect_types would be unbound below.
            raise ValueError(
                "matmulOp: no incorrect output types defined for input dtype "
                f"{a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005525
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor for CONCAT.

    The shape is the first input's shape with dimension ``axis`` summed over
    all inputs. When the sum cannot be formed (rank mismatch or an invalid
    axis), the first input's shape is used unchanged.
    """
    first = inputs[0]
    others = inputs[1:]

    out_shape = first.shape.copy()
    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005561
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for PAD.

    Each dimension grows by its before (``padding[i][0]``) and after
    (``padding[i][1]``) amounts. ERROR_IF tests may perturb one dimension,
    change the rank, or pick a wrong dtype.
    """
    out_shape = a.shape.copy()
    for dim, extent in enumerate(a.shape):
        out_shape[dim] = padding[dim][0] + padding[dim][1] + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005592
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for DIM: a single-element SHAPE tensor.

    ``a`` and ``axis`` do not influence the result here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    # SHAPE is not among these, so every candidate is a wrong output dtype.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([1], rng.choice(list(set(candidates))))
5613
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for RESHAPE using the requested ``shape``.

    For TensorSizeInputOutputMismatch, every dimension is enlarged by 1-9.
    For WrongOutputType, the dtype is a random one different from ``a``'s.
    """
    out_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(out_shape)):
            out_shape[dim] += rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005638
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor for SLICE, shaped by ``size``.

    ERROR_IF variants: a wrong dtype, perturbed dims, the full input shape
    (start/length mismatch), or a different rank. ``start`` is not used here.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim, extent in enumerate(out_shape):
            # Small dims only grow; larger ones may also shrink.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            out_shape[dim] = extent + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005672
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for TILE: dimension i is a.shape[i] * multiples[i]."""
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for dim, extent in enumerate(a.shape):
        out_shape[dim] = extent * multiples[dim]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005701
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for TRANSPOSE: dimension i is a.shape[perms[i]].

    With out-of-range or repeated permutation indices, the input shape is
    kept as is. Further ERROR_IF variants enlarge every dim, change the
    rank, or pick a wrong dtype.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    if error_name not in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        for dim in range(len(out_shape)):
            src = perms[dim]
            out_shape[dim] = a.shape[src]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(out_shape)):
            out_shape[dim] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005734
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for GATHER.

    The result is [values.shape[0], indices.shape[1], values.shape[2]].
    Ranks (3 and 2) and the shared batch dim are asserted unless the test
    targets ErrorIf.WrongRank.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005760
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for SCATTER.

    The result has the same shape as ``values_in``. Operand ranks and the
    shared N/W/C dims are asserted unless the test targets ErrorIf.WrongRank.
    For ErrorIf.WrongOutputType, the dtype is a random one different from
    ``values_in``'s.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy (as the other shapers do) so that the output tensor and values_in
    # do not share one mutable shape list.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005791
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for TABLE, with the same shape as ``input``.

    INT16 input gives an INT32 output; any other input dtype gives INT8. For
    ErrorIf.WrongOutputType, a different dtype is chosen at random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT8
    if input.dtype == DType.INT16:
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice([d for d in candidates if d != out_dtype])

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005811
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for RESIZE (NHWC).

    ``scale`` is read as [y_numerator, y_denominator, x_numerator,
    x_denominator]. OH/OW come from the input H/W, scale, offset and border.
    For any ERROR_IF test, the dims are first clamped to a usable range. For
    specific errors, the spatial, batch or channel dims or the rank are then
    deliberately perturbed. ``mode`` and ``input_dtype`` are not used here.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Positive stand-ins keep the shape arithmetic well defined.
        y_num, y_den, x_num, x_den = (
            max(value, 1) for value in (y_num, y_den, x_num, x_den)
        )

    in_h, in_w = input.shape[1], input.shape[2]
    oh = ((in_h - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((in_w - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Changed scale/offset/border can give unusable dims; clamp them.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        # Step by whole denominators so the non-integer ERROR_IF is not hit.
        if which in (1, 3):
            if oh + y_den >= gtu.MAX_RESIZE_DIMENSION:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += y_den
        if which in (2, 3):
            if ow + x_den >= gtu.MAX_RESIZE_DIMENSION:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += x_den

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses the batch size, presumably because
        # a wrong-rank input may lack index 3 - confirm.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005890
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor of a type conversion: ``val``'s shape, ``out_dtype``."""
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005894
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for TRANSPOSE_CONV2D with the given ``output_shape``.

    NOTE(review): for ErrorIf.ConvOutputShapeMismatch, the caller's
    ``output_shape`` list is modified in place (H and/or W grown by 1-3).
    Anything else holding that list sees the perturbed values - confirm this
    is intended.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        which = rng.choice(options)
        # Axis 1 (H) is handled before axis 2 (W).
        for axis, selectors in ((1, (1, 3)), (2, (2, 3))):
            if which in selectors:
                output_shape[axis] += rng.choice(options)

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is the error under test; use a plausible output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 (presumably both are valid).
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005920
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Create the two output tensors for FFT2D.

    Both outputs (presumably the real and imaginary parts) share one shape
    and dtype, copied from the inputs unless an ERROR_IF test perturbs them.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        spatial_dim = rng.choice([1, 2])
        out_shape[spatial_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5951
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Create the two output tensors for RFFT2D.

    Both outputs keep all input dims except the last, which becomes
    ``W // 2 + 1``. Their dtype is ``value``'s unless an ERROR_IF test
    perturbs the dtype or shape.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        spatial_dim = rng.choice([1, 2])
        out_shape[spatial_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005976
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for ADD_SHAPE: a SHAPE tensor shaped like ``a``.

    ErrorIf.DimensionMismatch grows one randomly chosen dim by 1, and
    ErrorIf.WrongOutputType picks a non-SHAPE dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = list(a.shape)

    # Drawn on every call, even when the index is not used.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(out_shape, out_dtype)